In [None]:
import torch
from torch.utils.data import Dataset
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import numpy as np
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
from unet_model import UNet
from torchinfo import summary
%matplotlib inline

# One seed for both global RNGs the notebook draws from: torch's, and numpy's
# legacy global RNG (used by the np.random coin flips in MyDataset.transform).
SEED = 1337
torch.manual_seed(SEED)
np.random.seed(SEED)

In [None]:
class MyDataset(Dataset):
    """Paired image/target dataset backed by two ``.npy`` files.

    ``X`` is loaded from ``X_path`` and transposed with ``(0, 3, 1, 2)``.
    Presumably the file stores images channels-last as (N, H, W, C) and items
    are served channels-first as (C, H, W); TODO confirm against the exporter.
    Each target ``y[idx]`` gets a leading singleton axis via
    ``np.expand_dims(..., 0)`` (presumably (H, W) -> (1, H, W)).

    Args:
        X_path: path to the input-image array.
        y_path: path to the target array.
        transform_flag: if True, ``__getitem__`` runs ``transform`` (random
            crop, flip and colour jitter) and returns torch tensors. If False
            it returns the stored numpy arrays unchanged.
        crop_size: (height, width) of the random training crop. Items smaller
            than this make ``RandomCrop.get_params`` raise.
    """

    def __init__(self, X_path="dataset/x_train.npy", y_path="dataset/y_train.npy", transform_flag=False,
                 crop_size=(256, 256)):
        self.X = np.load(X_path).transpose(0, 3, 1, 2)
        self.y = np.load(y_path)
        self.transform_flag = transform_flag
        self.crop_size = tuple(crop_size)

    def __len__(self):
        return self.X.shape[0]

    def transform(self, image, mask):
        """Apply one random crop/flip to image and mask, then colour-jitter the image only.

        Uses the public ``torchvision.transforms.functional`` API. The private
        ``transforms.functional_tensor`` module used previously was deprecated
        in torchvision 0.15 and removed in 0.17.

        NOTE(review): torchvision's tensor colour ops treat floating-point
        images as lying in [0, 1] and clamp results to that range. They treat
        integer images as full-range. Confirm the dtype/range stored in X.
        """
        TF = transforms.functional
        image = torch.tensor(image)
        mask = torch.tensor(mask)

        # Random crop: a single window shared by image and mask keeps them aligned.
        i, j, h, w = transforms.RandomCrop.get_params(image, output_size=self.crop_size)
        image = TF.crop(image, i, j, h, w)
        mask = TF.crop(mask, i, j, h, w)

        # Random horizontal flip (p = 0.5), applied to both.
        if np.random.rand() > 0.5:
            image = TF.hflip(image)
            mask = TF.hflip(mask)

        # Photometric jitter on the image only. Each op fires with probability
        # 0.9 (rand() > 0.1). Factors are drawn uniformly from [0.5, 1.5), and
        # the hue shift from [-0.5, 0.5).
        if np.random.rand() > 0.1:
            image = TF.adjust_brightness(image, np.random.rand() + 0.5)
        if np.random.rand() > 0.1:
            image = TF.adjust_contrast(image, np.random.rand() + 0.5)
        if np.random.rand() > 0.1:
            image = TF.adjust_gamma(image, np.random.rand() + 0.5)
        if np.random.rand() > 0.1:
            image = TF.adjust_hue(image, np.random.rand() - 0.5)
        if np.random.rand() > 0.1:
            image = TF.adjust_saturation(image, np.random.rand() + 0.5)

        return image, mask

    def __getitem__(self, idx):
        image, mask = self.X[idx], np.expand_dims(self.y[idx], 0)
        if self.transform_flag:
            return self.transform(image, mask)
        return image, mask

In [None]:
# Only the training split is augmented (random crop / flip / colour jitter).
# Validation and test items are served exactly as stored.
train_dataset = MyDataset(X_path="dataset/x_train.npy", y_path="dataset/y_train.npy", transform_flag=True)
val_dataset = MyDataset(X_path="dataset/x_val.npy", y_path="dataset/y_val.npy", transform_flag=False)
test_dataset = MyDataset(X_path="dataset/x_test.npy", y_path="dataset/y_test.npy", transform_flag=False)

# hyper params
# NOTE(review): if UNet contains BatchNorm layers, batch_size=1 makes batch
# statistics very noisy during training; confirm this is intended.
batch_size = 1

# Shuffle only the training loader so evaluation order stays fixed.
train_loader, val_loader, test_loader = (
    torch.utils.data.DataLoader(split, batch_size=batch_size, shuffle=is_train)
    for split, is_train in ((train_dataset, True), (val_dataset, False), (test_dataset, False))
)

In [None]:
# Prefer the GPU, but fall back to CPU so the notebook still runs without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Rebuild the architecture and warm-start it from the models/Unet_l1 checkpoint.
# Presumably UNet(n_channels=3, n_classes=1) -- TODO confirm against unet_model.UNet.
model = UNet(3, 1, bilinear=False)
# map_location lets a checkpoint saved from a CUDA run load on the selected device.
model.load_state_dict(torch.load("models/Unet_l1/model_100.pth", map_location=device))
model = model.to(device)

# NOTE(review): lr=0.01 is high for Adam when fine-tuning a pretrained
# checkpoint; confirm intent.
optimizer = optim.Adam(model.parameters(), lr=0.01, weight_decay=1e-3)
# Lowers the LR once the value passed to scheduler.step() has not decreased
# for `patience` epochs.
# NOTE(review): `verbose` is deprecated in recent PyTorch releases.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=5, verbose=True)

In [None]:
# Layer / parameter overview of the loaded UNet. No input_size is passed, so
# torchinfo runs no forward pass and reports no output shapes.
summary(model=model)

In [None]:
def logMSE(pred, groundtruth, lamda = 0.5, eps=1e-6):
    """Scale-invariant log-space error, as in https://arxiv.org/pdf/1406.2283.pdf.

        D = 1/n * sum(d_i ** 2) - lamda / n**2 * (sum(d_i)) ** 2,
        d_i = log(pred_i) - log(groundtruth_i)

    lamda=0 gives plain MSE in log space. lamda=1 makes the loss invariant to
    a global scale factor on ``pred``. The paper uses lamda=0.5.

    Args:
        pred: predicted tensor; strictly positive values are expected.
        groundtruth: target tensor, broadcast-compatible with ``pred``.
        lamda: weight of the scale-invariance term.
        eps: if not None, both inputs are clamped to at least ``eps`` before
            the log. Zero or negative values then give a finite penalty
            instead of -inf/NaN. Clamped elements receive zero gradient.

    Returns:
        0-dim tensor. ``n`` is the element count of the whole difference
        tensor, so a batch is treated as a single image (the paper normalises
        per image; identical when batch_size == 1).
    """
    def _log(x):
        if not torch.is_floating_point(x):
            x = x.float()  # torch.log promotes integer tensors to float anyway
        if eps is not None:
            x = x.clamp_min(eps)
        return torch.log(x)

    d = _log(pred) - _log(groundtruth)
    n = torch.numel(d)
    first_term = torch.sum(d ** 2) / n
    second_term = torch.sum(d) ** 2 / n ** 2
    # The paper SUBTRACTS the squared-mean term: that is what forgives a
    # global scale error. Adding it would penalise scale errors even more.
    return first_term - lamda * second_term

def train(epoch):
    """Run one optimisation pass over ``train_loader`` and return the mean batch loss.

    Relies on the notebook-level ``model``, ``optimizer``, ``device``,
    ``train_loader`` and ``logMSE``.

    Args:
        epoch: epoch number, used only in the progress messages.

    Returns:
        Mean of the per-batch ``logMSE`` values over the epoch.
    """
    model.train()
    print(f"Running Epoch {epoch}")
    running_loss = 0.0
    for data, target in tqdm(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data.float())
        loss = logMSE(output, target)
        loss.backward()
        running_loss += loss.item()
        optimizer.step()

    # logMSE already normalises by the element count of the whole batch, so
    # average over the number of batches, not samples. The two only agree
    # when batch_size == 1.
    epoch_loss = running_loss / len(train_loader)
    print(f"Epoch {epoch} : Avg Loss : {epoch_loss}")
    return epoch_loss
        
def validation():
    """Evaluate ``model`` on ``val_loader`` and return the mean batch loss.

    Relies on the notebook-level ``model``, ``device``, ``val_loader`` and
    ``logMSE``. Leaves the model in eval mode; ``train`` switches it back.

    Returns:
        Mean of the per-batch ``logMSE`` values over the validation set.
    """
    model.eval()
    total_loss = 0.0
    # Evaluation needs no gradients; skipping autograd avoids building a graph
    # per batch and the memory that goes with it.
    with torch.no_grad():
        for data, target in val_loader:
            data, target = data.to(device), target.to(device)
            output = model(data.float())
            total_loss += logMSE(output, target).item()

    # logMSE returns a per-batch normalised value, so divide by the batch count.
    validation_loss = total_loss / len(val_loader)
    print(f'Validation set: Average loss: {validation_loss}')
    return validation_loss

In [None]:
epochs = 20

# Per-epoch history, filled in below and plotted afterwards.
train_loss = []
validation_loss = []
learning_rate = []

for epoch in range(1, epochs + 1):
    # One optimisation pass, then an evaluation pass on the validation split.
    train_loss.append(train(epoch))

    epoch_val_loss = validation()
    validation_loss.append(epoch_val_loss)

    # The plateau scheduler monitors validation loss; record the LR now in effect.
    scheduler.step(epoch_val_loss)
    learning_rate.append(optimizer.param_groups[0]['lr'])

    # One checkpoint per epoch, named with the 1-based epoch number.
    torch.save(model.state_dict(), f'models/model_{epoch}.pth')

In [None]:
# Plot against 1-based epoch numbers so the x-axis matches the checkpoint
# filenames (models/model_<epoch>.pth) and the best-epoch printout.
fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(range(1, len(train_loss) + 1), train_loss, label="Training loss")
ax.plot(range(1, len(validation_loss) + 1), validation_loss, label="Validation loss")
ax.set(title="Training and validation loss per epoch", xlabel="Epochs", ylabel="Loss")
ax.legend(loc='best')
plt.show()

In [None]:
# validation_loss[k] belongs to epoch k + 1, hence the +1 (matches models/model_<epoch>.pth).
print("model with best validation loss is {}".format(np.argmin(validation_loss) + 1))