In [None]:
import os

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset

In [None]:
# Training hyper-parameters.
batch_size = 32
training_epochs = 200
learning_rate = 1e-3

# Seed the global RNGs so weight initialisation and DataLoader shuffling are
# reproducible across Restart & Run All (cuDNN kernels may still be
# non-deterministic on GPU).
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
# The pre-split train/test arrays are stored together in one NumPy .npz archive.
filename = 'splitted_dataset.npz'
dataset = np.load(file=filename)

In [None]:
# Pull split "_1" out of the archive: model inputs and the clean targets the
# network is trained to reproduce.
(inputs_train, cleans_train,
 inputs_test, cleans_test) = (dataset[key] for key in ('inputs_train_1',
                                                       'cleans_train_1',
                                                       'inputs_test_1',
                                                       'cleans_test_1'))

In [None]:
# Inspect the split shapes, then check the invariants the rest of the notebook
# relies on: every input is paired one-to-one with a clean target of the same
# shape. SegNet returns a tensor shaped like its input, and MSELoss only warns
# (and broadcasts) when prediction and target shapes differ.
print(inputs_train.shape)
print(cleans_train.shape)
print(inputs_test.shape)
print(cleans_test.shape)

assert inputs_train.shape == cleans_train.shape, 'train inputs/targets shape mismatch'
assert inputs_test.shape == cleans_test.shape, 'test inputs/targets shape mismatch'

In [None]:
class AutoEncoder_dataset(Dataset):
    """Map-style dataset pairing each model input with its clean target.

    Both arrays are converted once, up front, to float32 tensors; indexing
    returns the ``(input, clean)`` pair at that position.

    Args:
        inputs: array-like of model inputs, indexed along the first axis.
        cleans: array-like of targets, with the same length as ``inputs``.

    Raises:
        ValueError: if ``inputs`` and ``cleans`` have different lengths.
    """

    def __init__(self, inputs, cleans):
        # as_tensor shares memory with a NumPy array (like from_numpy) but also
        # accepts tensors and nested sequences; .float() casts to float32.
        self.inputs = torch.as_tensor(inputs).float()
        self.cleans = torch.as_tensor(cleans).float()
        # Fail fast: a length mismatch would otherwise surface as an IndexError
        # mid-epoch, or silently drop the surplus targets.
        if len(self.inputs) != len(self.cleans):
            raise ValueError(
                f'inputs and cleans must have the same length, '
                f'got {len(self.inputs)} and {len(self.cleans)}'
            )

    def __getitem__(self, index):
        return self.inputs[index], self.cleans[index]

    def __len__(self):
        return len(self.inputs)

In [None]:
# Wrap the raw train and test arrays as torch Datasets.
train_set, test_set = (
    AutoEncoder_dataset(inputs=x, cleans=y)
    for x, y in ((inputs_train, cleans_train), (inputs_test, cleans_test))
)

In [None]:
# Training: reshuffle each epoch and drop the ragged final batch.
# Test: iterate in order and keep every sample.
train_loader = DataLoader(train_set, batch_size, shuffle=True, drop_last=True)
test_loader = DataLoader(test_set, batch_size, shuffle=False, drop_last=False)

In [None]:
class SegNet(nn.Module):
    
    def __init__(self):
        super(SegNet, self).__init__()
        
        # Encoder
        
        self.Enc_0_1 = nn.Sequential(
            nn.Conv2d(in_channels=2, out_channels=64, kernel_size=3, padding=1),
            nn.BatchNorm2d(num_features=64),
            nn.ReLU(inplace=True)
        )
        
        self.Enc_1_1 = nn.Sequential(
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1),
            nn.BatchNorm2d(num_features=128),
            nn.ReLU(inplace=True)
        )
        
        self.Enc_2_1 = nn.Sequential(
            nn.Conv2d(in_channels=128, out_channels= 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(num_features=256),
            nn.ReLU(inplace=True)
        )
        
        self.Enc_2_2 = nn.Sequential(
            nn.Conv2d(in_channels=256, out_channels= 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(num_features=256),
            nn.ReLU(inplace=True)
        )
        
        self.Enc_3_1 = nn.Sequential(
            nn.Conv2d(in_channels=256, out_channels= 512, kernel_size=3, padding=1),
            nn.BatchNorm2d(num_features=512),
            nn.ReLU(inplace=True)
        )
        
        self.Enc_3_2 = nn.Sequential(
            nn.Conv2d(in_channels=512, out_channels= 512, kernel_size=3, padding=1),
            nn.BatchNorm2d(num_features=512),
            nn.ReLU(inplace=True)
        )
        
        self.Enc_4_1 = nn.Sequential(
            nn.Conv2d(in_channels=512, out_channels= 512, kernel_size=3, padding=1),
            nn.BatchNorm2d(num_features=512),
            nn.ReLU(inplace=True)
        )
        
        self.Enc_4_2 = nn.Sequential(
            nn.Conv2d(in_channels=512, out_channels= 512, kernel_size=3, padding=1),
            nn.BatchNorm2d(num_features=512),
            nn.ReLU(inplace=True)
        )
        
        self.Pool = nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True)
        
        # Decoder
        
        self.Dec_4_2 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=512, out_channels=512, kernel_size=3, padding=1),
            nn.BatchNorm2d(num_features=512),
            nn.ReLU(inplace=True)
        )
        
        self.Dec_4_1 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=512, out_channels=512, kernel_size=3, padding=1),
            nn.BatchNorm2d(num_features=512),
            nn.ReLU(inplace=True)
        )
        
        self.Dec_3_2 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=512, out_channels=512, kernel_size=3, padding=1),
            nn.BatchNorm2d(num_features=512),
            nn.ReLU(inplace=True)
        )
        
        self.Dec_3_1 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=512, out_channels=256, kernel_size=3, padding=1),
            nn.BatchNorm2d(num_features=256),
            nn.ReLU(inplace=True)
        )
        
        self.Dec_2_2 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=256, out_channels=256, kernel_size=3, padding=1),
            nn.BatchNorm2d(num_features=256),
            nn.ReLU(inplace=True)
        )
        
        self.Dec_2_1 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=3, padding=1),
            nn.BatchNorm2d(num_features=128),
            nn.ReLU(inplace=True)
        )
        
        self.Dec_1_1 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=3, padding=1),
            nn.BatchNorm2d(num_features=64),
            nn.ReLU(inplace=True)
        )
        
        self.Dec_0_1 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=64, out_channels=2, kernel_size=3, padding=1),
            nn.BatchNorm2d(num_features=2),
            nn.ReLU(inplace=True)
        )
        
        self.Unpool = nn.MaxUnpool2d(kernel_size=2, stride=2)
        
        
    def forward(self, inputs):
        
        # Encoder
        
        dim_0 = inputs.size()
        outputs = self.Enc_0_1(inputs)
        outputs, indice_0 = self.Pool(outputs)
        
        dim_1 = outputs.size()
        outputs = self.Enc_1_1(outputs)
        outputs, indice_1 = self.Pool(outputs)
        
        dim_2 = outputs.size()
        outputs = self.Enc_2_1(outputs)
        outputs = self.Enc_2_2(outputs)
        outputs, indice_2 = self.Pool(outputs)
        
        dim_3 = outputs.size()
        outputs = self.Enc_3_1(outputs)
        outputs = self.Enc_3_2(outputs)
        outputs, indice_3 = self.Pool(outputs)
        
        dim_4 = outputs.size()
        outputs = self.Enc_4_1(outputs)
        outputs = self.Enc_4_2(outputs)
        outputs, indice_4 = self.Pool(outputs)
        
        dim_middle = outputs.size()
        
        # Decoder
        outputs = self.Unpool(outputs, indice_4, output_size=dim_4)
        outputs = self.Dec_4_2(outputs)
        outputs = self.Dec_4_1(outputs)
        dim_4d = outputs.size()
        
        outputs = self.Unpool(outputs, indice_3, output_size=dim_3)
        outputs = self.Dec_3_2(outputs)
        outputs = self.Dec_3_1(outputs)
        dim_3d = outputs.size()
        
        outputs = self.Unpool(outputs, indice_2, output_size=dim_2)
        outputs = self.Dec_2_2(outputs)
        outputs = self.Dec_2_1(outputs)
        dim_2d = outputs.size()
        
        outputs = self.Unpool(outputs, indice_1, output_size=dim_1)
        outputs = self.Dec_1_1(outputs)
        dim_1d = outputs.size()
        
        outputs = self.Unpool(outputs, indice_0, output_size=dim_0)
        outputs = self.Dec_0_1(outputs)
        dim_0d = outputs.size()
        
        return outputs
        

In [None]:
GPU_NUM = 6  # index of the GPU to use when CUDA is available

# Fall back to the CPU when CUDA is missing or GPU_NUM is not a valid device
# index on this machine, instead of crashing.
use_cuda = torch.cuda.is_available() and GPU_NUM < torch.cuda.device_count()
device = torch.device(f'cuda:{GPU_NUM}' if use_cuda else 'cpu')

if device.type == 'cuda':
    # set_device / current_device are CUDA-only calls; keep them behind the guard.
    torch.cuda.set_device(device)
    print('Current cuda device ', torch.cuda.current_device())
    print(torch.cuda.get_device_name(device))
    print('Memory Usage:')
    print('Allocated:', round(torch.cuda.memory_allocated(device) / 1024**3, 1), 'GB')
    # memory_reserved replaces the deprecated memory_cached.
    print('Cached:   ', round(torch.cuda.memory_reserved(device) / 1024**3, 1), 'GB')
else:
    print('CUDA device', GPU_NUM, 'unavailable; using', device)

In [None]:
# Build the model, then move its parameters and buffers to the selected device.
net = SegNet()
net = net.to(device)

In [None]:
PATH = 'SegNet.pt'  # checkpoint that the next cell resumes from

In [None]:
optimizer = optim.Adam(net.parameters(), lr=learning_rate)

# Resume model and optimizer state when a checkpoint exists. Otherwise keep the
# fresh initialisation, so the notebook also runs from scratch.
# NOTE(review): torch.load unpickles the file. Only load trusted checkpoints,
# and pass weights_only=True on PyTorch versions that support it.
try:
    checkpoint = torch.load(PATH, map_location=device)
except FileNotFoundError:
    print(f'No checkpoint at {PATH!r}; training from scratch.')
else:
    net.load_state_dict(checkpoint['State_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    print(f"Resumed from {PATH!r} (saved at epoch {checkpoint.get('Epoch')})")


In [None]:
# Element-wise squared error between reconstruction and clean target, averaged.
criterion = nn.MSELoss(reduction='mean')

In [None]:
# No-op if the model is already on `device` (it is moved there when built).
# The trailing `;` stops IPython from displaying the very long module repr.
net.to(device);

In [None]:
def save_checkpoint(epoch, model, optimizer, filename):
    """Save the epoch, model weights and optimizer state to ``filename``.

    The checkpoint is a dict with keys ``'Epoch'``, ``'State_dict'`` and
    ``'optimizer'``.

    When ``filename`` is a path, the data is first written to a sibling
    ``.tmp`` file and then renamed over the target with ``os.replace``. An
    interrupted save therefore cannot leave a truncated file in place of the
    previous checkpoint. File-like objects are passed straight to ``torch.save``.
    """
    state = {
        'Epoch': epoch,
        'State_dict': model.state_dict(),
        'optimizer': optimizer.state_dict(),
    }
    if not isinstance(filename, (str, os.PathLike)):
        torch.save(state, filename)
        return

    tmp_path = os.fspath(filename) + '.tmp'
    try:
        torch.save(state, tmp_path)
        os.replace(tmp_path, filename)
    except BaseException:
        # Don't leave a half-written temp file behind; re-raise the original error.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise

In [None]:
total_batch = len(train_loader)

for epoch in range(training_epochs):
    # Re-enter training mode every epoch, so BatchNorm uses batch statistics
    # even if another cell switched the model to eval mode.
    net.train()
    avg_cost = 0.0

    for inputs, cleans in train_loader:
        inputs = inputs.to(device)
        cleans = cleans.to(device)

        optimizer.zero_grad()
        hypothesis = net(inputs)
        cost = criterion(hypothesis, cleans)
        cost.backward()
        optimizer.step()

        # .item() turns the loss into a plain Python float. Accumulating the
        # tensor itself would chain every batch's autograd graph into avg_cost.
        avg_cost += cost.item() / total_batch

    if epoch % 10 == 9:
        print('[Epoch : {:>4} / {:>3}] cost = {:>.9}'.format(epoch + 1, training_epochs, avg_cost))

        # NOTE: saves to 'SegNet2.pt', not PATH, so the checkpoint this run
        # resumed from is left untouched.
        save_checkpoint(epoch, net, optimizer, 'SegNet2.pt')
        print("======= Saved Model =======")

print('Learning finished')