In [1]:
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset
from torchinfo import summary
from torchvision import datasets, transforms
from tqdm.notebook import tqdm

from unet_model import UNet
%matplotlib inline

# Fix the torch and NumPy RNGs so a Restart & Run All reproduces the same run.
SEED = 1337
torch.manual_seed(SEED)
np.random.seed(SEED)

In [2]:
class MyDataset(Dataset):
    """In-memory dataset of (image, target) pairs loaded from two ``.npy`` files.

    The image array is rank 4 and is stored with its axes permuted by
    ``(0, 3, 1, 2)``, i.e. channels-first per sample.
    NOTE(review): this assumes the file is saved as (N, H, W, C) — confirm.

    Args:
        X_path: path to the image array.
        y_path: path to the target array; one entry per image.
        transform: optional callable applied to each image. It receives the
            image in its original channels-last (H, W, C) layout, which is what
            torchvision transforms such as ``ToTensor`` expect.

    Raises:
        ValueError: if the two arrays hold a different number of samples.
    """

    def __init__(self, X_path="dataset/x_train.npy", y_path="dataset/y_train.npy", transform=None):
        self.X = np.load(X_path).transpose(0, 3, 1, 2)
        self.y = np.load(y_path)
        # A length mismatch would silently misalign pairs or fail mid-epoch.
        if self.X.shape[0] != len(self.y):
            raise ValueError(
                f"X has {self.X.shape[0]} samples but y has {len(self.y)}"
            )
        self.transform = transform

    def __len__(self):
        """Number of samples (first axis of the image array)."""
        return self.X.shape[0]

    def __getitem__(self, idx):
        """Return ``(image, target)``; the target gets a leading axis of size 1."""
        target = np.expand_dims(self.y[idx], 0)
        if self.transform is None:
            return self.X[idx], target
        # Undo the CHW permutation: ToTensor-style transforms expect HWC input
        # and perform the HWC -> CHW conversion themselves.
        return self.transform(self.X[idx].transpose(1, 2, 0)), target
        
# Single-step pipeline: ToTensor converts an HWC array/PIL image to a CHW
# tensor (rescaling uint8 input to [0, 1]).
data_transform = transforms.Compose([transforms.ToTensor()])

In [3]:
# Hyper-parameters.
# NOTE(review): batch_size=1 makes BatchNorm batch statistics very noisy
# (the model summary lists BatchNorm2d layers) — consider a larger batch.
batch_size = 1

train_dataset = MyDataset(X_path="dataset/x_train.npy", y_path="dataset/y_train.npy")
val_dataset = MyDataset(X_path="dataset/x_val.npy", y_path="dataset/y_val.npy")
test_dataset = MyDataset(X_path="dataset/x_test.npy", y_path="dataset/y_test.npy")

# Only the training split is shuffled; evaluation order stays fixed.
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
val_loader = torch.utils.data.DataLoader(dataset=val_dataset, batch_size=batch_size, shuffle=False)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

In [4]:
# Fall back to CPU so the notebook still runs on machines without a GPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# 3 input channels -> 1 output channel; transposed-conv upsampling (bilinear=False).
model = UNet(3, 1, bilinear=False).to(device)

optimizer = optim.Adam(model.parameters(), weight_decay=1e-3)
# Lower the LR when the monitored (validation) loss stops decreasing for 4 epochs.
# NOTE(review): `verbose` is deprecated in recent PyTorch releases.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=4, verbose=True)

In [5]:
# Layer-by-layer parameter counts, nested up to torchinfo's default depth of 3.
summary(model=model, depth=3)

Layer (type:depth-idx)                        Param #
UNet                                          --
├─DoubleConv: 1-1                             --
│    └─Sequential: 2-1                        --
│    │    └─Conv2d: 3-1                       1,792
│    │    └─BatchNorm2d: 3-2                  128
│    │    └─ReLU: 3-3                         --
│    │    └─Conv2d: 3-4                       36,928
│    │    └─BatchNorm2d: 3-5                  128
│    │    └─ReLU: 3-6                         --
├─Down: 1-2                                   --
│    └─Sequential: 2-2                        --
│    │    └─MaxPool2d: 3-7                    --
│    │    └─DoubleConv: 3-8                   221,952
├─Down: 1-3                                   --
│    └─Sequential: 2-3                        --
│    │    └─MaxPool2d: 3-9                    --
│    │    └─DoubleConv: 3-10                  886,272
├─Down: 1-4                                   --
│    └─Sequential: 2-4                       

In [6]:
def train(epoch):
    """Run one training epoch over ``train_loader`` and return its mean L1 loss.

    Uses the module-level ``model``, ``optimizer``, ``train_loader`` and
    ``device``.

    Args:
        epoch: epoch number, used only in the progress messages.

    Returns:
        The batch-size-weighted mean of the per-batch L1 loss, i.e. an
        average per sample. This is correct for any batch size, including a
        smaller final batch.
    """
    model.train()
    running_loss = 0.0
    print(f"Running Epoch {epoch}")
    for data, target in tqdm(train_loader):
        data = data.to(device).float()
        target = target.to(device)
        optimizer.zero_grad()
        output = model(data)
        # Match the target dtype to the output: a float64/integer target
        # vs a float32 output can fail in l1_loss or its backward pass.
        loss = F.l1_loss(output, target.to(output.dtype))
        loss.backward()
        optimizer.step()
        # l1_loss returns a per-batch mean; weight it by the batch size.
        running_loss += loss.item() * data.size(0)

    train_loss = running_loss / len(train_loader.dataset)
    print(f"Epoch {epoch} : Avg Loss : {train_loss}")
    return train_loss
    
        
def validation():
    """Evaluate ``model`` on ``val_loader`` and return the mean L1 loss.

    Uses the module-level ``model``, ``val_loader`` and ``device``. It runs
    under ``torch.no_grad()`` so no autograd graph is built.

    Returns:
        The batch-size-weighted mean of the per-batch L1 loss (average per
        sample). This is the same quantity that ``train`` reports.
    """
    model.eval()
    running_loss = 0.0
    with torch.no_grad():
        for data, target in val_loader:
            data = data.to(device).float()
            target = target.to(device)
            output = model(data)
            batch_loss = F.l1_loss(output, target.to(output.dtype))
            # l1_loss returns a per-batch mean; weight it by the batch size.
            running_loss += batch_loss.item() * data.size(0)

    validation_loss = running_loss / len(val_loader.dataset)
    print(f'Validation set: Average loss: {validation_loss}')
    return validation_loss

In [None]:
epochs = 20
model_dir = "models"
# torch.save does not create missing directories.
os.makedirs(model_dir, exist_ok=True)

# Per-epoch history.
train_loss = []
validation_loss = []
learning_rate = []

for epoch in range(1, epochs + 1):
    epoch_train_loss = train(epoch)
    train_loss.append(epoch_train_loss)

    epoch_val_loss = validation()
    validation_loss.append(epoch_val_loss)

    # The plateau scheduler ('min' mode) is driven by the validation loss.
    scheduler.step(epoch_val_loss)
    learning_rate.append(optimizer.param_groups[0]['lr'])

    # Checkpoint every epoch; `model_file` ends up naming the last checkpoint.
    model_file = os.path.join(model_dir, f"model_{epoch}.pth")
    torch.save(model.state_dict(), model_file)

Running Epoch 1


  0%|          | 0/1014 [00:00<?, ?it/s]