In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torch.utils.data import Dataset, DataLoader
from glob import glob
import time
import matplotlib.pyplot as plt
from torchsummary import summary
from torch.utils.tensorboard import SummaryWriter
import datetime

# Report whether PyTorch can see a CUDA device; later cells call .cuda()
# unconditionally and will fail if this prints False.
print(torch.cuda.is_available())

# Create the Dataset

In [None]:
# Image width and height in pixels.
IMAGE_W = 512
IMAGE_H = 256

In [None]:
class ImageDataset(Dataset):
    """Dataset of PNG images read from ``./new_ds/<split>/``.

    Every item is a single image tensor: the file is decoded with
    ``torchvision.io.read_image``, divided by 255 and then normalised
    channel-wise with the fixed mean/std given below.
    """

    def __init__(self, train_val_test=0):
        """Collect the image paths of one split.

        Args:
            train_val_test: Index into ``['train', 'val', 'test']`` selecting
                the sub-directory of ``./new_ds`` to read from.
        """
        # glob() returns paths in arbitrary, filesystem-dependent order. Sort
        # them so that a given index always maps to the same file (the
        # training cell picks fixed validation indices for its preview grid).
        self.images = sorted(glob(f"./new_ds/{['train', 'val', 'test'][train_val_test]}/*.png"))
        self.transform = torchvision.transforms.Compose([
            torchvision.transforms.Normalize(mean=[0.28689554, 0.32513303, 0.28389177],
                                             std=[0.18696375, 0.19017339, 0.18720214])
        ])

    def __len__(self):
        """Return the number of image files found for the split."""
        return len(self.images)

    def __getitem__(self, idx):
        """Load the image at position ``idx`` and return it normalised.

        Args:
            idx: Position in the sorted list of image paths.

        Returns:
            The normalised image as a float tensor (not a PIL image).
        """
        img_path = self.images[idx]
        # read_image yields integer pixel values; dividing by 255 rescales
        # them to floats before the mean/std normalisation is applied.
        img = torchvision.io.read_image(img_path) / 255.
        return self.transform(img)

In [None]:
# Sanity check: load one training image and print its mean and standard
# deviation after normalisation.
img_ds = ImageDataset()
test_img = img_ds[4]

print(test_img.mean(), test_img.std(), sep="\n")

In [None]:
# print(len(img_ds))

# t1 = time.time()

# for img in img_ds:
#     y = img
    
# print(f"Total time to go through ds: {time.time() - t1:.2f}s")

In [None]:
def visualize(x):
    """Undo the dataset normalisation so a tensor can be shown as an image.

    Computes ``x * std + mean`` per channel with the same statistics that
    ``ImageDataset`` passes to ``Normalize``.

    Args:
        x: Tensor laid out as (C, H, W) or (N, C, H, W) with C == 3, on any
            device.

    Returns:
        A new tensor with the same shape and device as ``x``. The argument is
        never modified, so a tensor can be visualised and still be fed to the
        model afterwards.
    """
    stats_dtype = x.dtype if x.is_floating_point() else torch.get_default_dtype()
    # Shape (3, 1, 1) broadcasts over H and W, and over a leading batch dim.
    std = torch.tensor([0.18696375, 0.19017339, 0.18720214],
                       dtype=stats_dtype, device=x.device).view(-1, 1, 1)
    mean = torch.tensor([0.28689554, 0.32513303, 0.28389177],
                        dtype=stats_dtype, device=x.device).view(-1, 1, 1)
    # Out-of-place on purpose: the previous in-place `*=` / `+=` operated on a
    # view of the caller's tensor and silently de-normalised it as well.
    return x * std + mean

In [None]:
# De-normalise one training image, add a leading batch dimension, and print
# its value range before plotting it.
test_img = visualize(img_ds[4])[None]

print("min:", test_img.min())
print("max:", test_img.max())

plt.figure(figsize=(15, 7))
# imshow wants channels last, so drop the batch dim and permute the axes.
plt.imshow(test_img[0].permute(1, 2, 0))

# Create the model

In [8]:
# CURRENT: https://medium.com/@tioluwaniaremu/vgg-16-a-simple-implementation-using-pytorch-7850be4d14a1

class Model(nn.Module):
    """VGG-style convolutional autoencoder.

    Encoder: six stages of 3x3 convolutions, each followed by a 2x2 max-pool,
    so height and width shrink by a factor of 64 overall.
    Decoder: six stages, each a 2x2 stride-2 transposed convolution followed
    by two 3x3 convolutions, then a 1x1 convolution back to 3 channels.

    Input height and width must be multiples of 64 for the output to have the
    same spatial size as the input. Shape comments below are (C, H, W) for a
    3 x 256 x 512 input.
    """

    def __init__(self):
        super(Model, self).__init__()
        # input: 3, 256, 512
        self.conv1_1 = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, padding=1)
        self.conv1_2 = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, padding=1)
        # after pool: 32, 128, 256
        self.conv2_1 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1)
        self.conv2_2 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding=1)
        # after pool: 64, 64, 128
        self.conv3_1 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1)
        self.conv3_2 = nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, padding=1)
        self.conv3_3 = nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, padding=1)
        # after pool: 128, 32, 64
        self.conv4_1 = nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, padding=1)
        self.conv4_2 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, padding=1)
        self.conv4_3 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, padding=1)
        # after pool: 256, 16, 32
        self.conv5_1 = nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, padding=1)
        self.conv5_2 = nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, padding=1)
        self.conv5_3 = nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, padding=1)
        # after pool: 512, 8, 16
        self.conv6_1 = nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, padding=1)
        self.conv6_2 = nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, padding=1)
        self.conv6_3 = nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, padding=1)
        # after pool (bottleneck): 512, 4, 8
        self.convup6_3 = nn.ConvTranspose2d(in_channels=512, out_channels=512, kernel_size=2, stride=2)
        self.convup6_2 = nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, padding=1)
        self.convup6_1 = nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, padding=1)
        # 512, 8, 16
        self.convup5_3 = nn.ConvTranspose2d(in_channels=512, out_channels=512, kernel_size=2, stride=2)
        self.convup5_2 = nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, padding=1)
        self.convup5_1 = nn.Conv2d(in_channels=512, out_channels=256, kernel_size=3, padding=1)
        # 256, 16, 32
        self.convup4_3 = nn.ConvTranspose2d(in_channels=256, out_channels=256, kernel_size=2, stride=2)
        self.convup4_2 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, padding=1)
        self.convup4_1 = nn.Conv2d(in_channels=256, out_channels=128, kernel_size=3, padding=1)
        # 128, 32, 64
        self.convup3_3 = nn.ConvTranspose2d(in_channels=128, out_channels=128, kernel_size=2, stride=2)
        self.convup3_2 = nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, padding=1)
        self.convup3_1 = nn.Conv2d(in_channels=128, out_channels=64, kernel_size=3, padding=1)
        # 64, 64, 128
        self.convup2_3 = nn.ConvTranspose2d(in_channels=64, out_channels=64, kernel_size=2, stride=2)
        self.convup2_2 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding=1)
        self.convup2_1 = nn.Conv2d(in_channels=64, out_channels=32, kernel_size=3, padding=1)
        # 32, 128, 256
        self.convup1_3 = nn.ConvTranspose2d(in_channels=32, out_channels=32, kernel_size=2, stride=2)
        self.convup1_2 = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, padding=1)
        self.convup1_1 = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, padding=1)
        # 32, 256, 512
        # in_channels must equal convup1_1's 32 output channels (it was 64,
        # which made the forward pass fail with a channel mismatch).
        self.convup0_1 = nn.Conv2d(in_channels=32, out_channels=3, kernel_size=1)
        # output: 3, 256, 512

        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        """Encode and reconstruct a batch of images.

        Args:
            x: Float tensor of shape (N, 3, H, W), H and W multiples of 64.

        Returns:
            Reconstruction of shape (N, 3, H, W). No activation is applied to
            the last layer, so outputs may be negative.
        """
        # ReLU runs in place on the fresh convolution outputs: results are
        # identical, but no second copy of every activation is kept.
        # encoder
        x = F.relu(self.conv1_1(x), inplace=True)
        x = F.relu(self.conv1_2(x), inplace=True)
        x = self.maxpool(x)
        x = F.relu(self.conv2_1(x), inplace=True)
        x = F.relu(self.conv2_2(x), inplace=True)
        x = self.maxpool(x)
        x = F.relu(self.conv3_1(x), inplace=True)
        x = F.relu(self.conv3_2(x), inplace=True)
        x = F.relu(self.conv3_3(x), inplace=True)
        x = self.maxpool(x)
        x = F.relu(self.conv4_1(x), inplace=True)
        x = F.relu(self.conv4_2(x), inplace=True)
        x = F.relu(self.conv4_3(x), inplace=True)
        x = self.maxpool(x)
        x = F.relu(self.conv5_1(x), inplace=True)
        x = F.relu(self.conv5_2(x), inplace=True)
        x = F.relu(self.conv5_3(x), inplace=True)
        x = self.maxpool(x)
        x = F.relu(self.conv6_1(x), inplace=True)
        x = F.relu(self.conv6_2(x), inplace=True)
        x = F.relu(self.conv6_3(x), inplace=True)
        x = self.maxpool(x)
        # decoder
        x = F.relu(self.convup6_3(x), inplace=True)
        x = F.relu(self.convup6_2(x), inplace=True)
        x = F.relu(self.convup6_1(x), inplace=True)
        x = F.relu(self.convup5_3(x), inplace=True)
        x = F.relu(self.convup5_2(x), inplace=True)
        x = F.relu(self.convup5_1(x), inplace=True)
        x = F.relu(self.convup4_3(x), inplace=True)
        x = F.relu(self.convup4_2(x), inplace=True)
        x = F.relu(self.convup4_1(x), inplace=True)
        x = F.relu(self.convup3_3(x), inplace=True)
        x = F.relu(self.convup3_2(x), inplace=True)
        x = F.relu(self.convup3_1(x), inplace=True)
        x = F.relu(self.convup2_3(x), inplace=True)
        x = F.relu(self.convup2_2(x), inplace=True)
        x = F.relu(self.convup2_1(x), inplace=True)
        x = F.relu(self.convup1_3(x), inplace=True)
        x = F.relu(self.convup1_2(x), inplace=True)
        x = F.relu(self.convup1_1(x), inplace=True)
        # Linear output layer: the reconstruction target is a mean/std
        # normalised image, which contains negative values a final ReLU could
        # never produce.
        x = self.convup0_1(x)

        return x

# Shape check: the reconstruction must have the same shape as the input.
print(f"Original\t{test_img.shape}")
# Only the output shape is inspected, so run without autograd; otherwise
# `out` would keep the whole forward graph alive for the rest of the session.
with torch.no_grad():
    out = Model()(test_img)
print(f"Model out\t{out.shape}")

Original	torch.Size([1, 3, 256, 512])
Model out	torch.Size([1, 3, 256, 512])


In [9]:
# Print a per-layer table (output shape, parameter count) for a 3 x 256 x 512 input.
# NOTE(review): the table recorded below lists 64 output channels for the first
# convolution while Model defines conv1_1 with 32 -- the saved output appears to
# predate the current definition; re-run this cell to refresh it.
summary(Model().cuda(), input_size=(3, 256, 512))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 256, 512]           1,792
            Conv2d-2         [-1, 64, 256, 512]          36,928
         MaxPool2d-3         [-1, 64, 128, 256]               0
            Conv2d-4        [-1, 128, 128, 256]          73,856
            Conv2d-5        [-1, 128, 128, 256]         147,584
         MaxPool2d-6         [-1, 128, 64, 128]               0
            Conv2d-7         [-1, 256, 64, 128]         295,168
            Conv2d-8         [-1, 256, 64, 128]         590,080
            Conv2d-9         [-1, 256, 64, 128]         590,080
        MaxPool2d-10          [-1, 256, 32, 64]               0
           Conv2d-11          [-1, 512, 32, 64]       1,180,160
           Conv2d-12          [-1, 512, 32, 64]       2,359,808
           Conv2d-13          [-1, 512, 32, 64]       2,359,808
        MaxPool2d-14          [-1, 512,

# Training Loop

In [13]:
# Training hyper-parameters read by the training cell.
BATCH_SIZE = 8          # images per mini-batch
EPOCHS = 1_000          # upper bound; early stopping may end training sooner
LEARNING_RATE = 1e-3    # Adam step size

In [14]:
# Number of consecutive epochs without a new best validation loss that are
# tolerated before training stops early.
PATIENCE = 10

writer = SummaryWriter('runs/vgg16' + '_' + datetime.datetime.now().strftime("%d%m-%H%M%S"))

# initialize model
m = Model().cuda()

# initialize optimizer
optim = torch.optim.Adam(m.parameters(), lr=LEARNING_RATE)

# dataloader
dl_train = torch.utils.data.DataLoader(ImageDataset(train_val_test=0), batch_size=BATCH_SIZE, shuffle=True, num_workers=4)

# set up validation datasets and validation images
val_ds = ImageDataset(train_val_test=1)
dl_val = torch.utils.data.DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=4)

val_img1, val_img2, val_img3 = val_ds[14], val_ds[307], val_ds[450]

# Stack the normalised model inputs BEFORE de-normalising anything for display,
# and only ever hand copies to visualize(): the preview inputs fed to the model
# must stay normalised exactly like the training batches.
val_set = torch.stack((val_img1, val_img2, val_img3))
img_h, img_w = val_set.shape[-2:]

# Preview grid: top row = original images, bottom row = reconstructions.
grid = torch.zeros((3, img_h * 2, img_w * len(val_set)))
for col, img in enumerate(val_set):
    grid[:, :img_h, col * img_w:(col + 1) * img_w] = visualize(img.clone()).cpu()

# loss function
criterion = torch.nn.MSELoss()

best_loss, patience = float('inf'), 0
# CPU copy of the weights that achieved `best_loss`; restored after training.
best_state = None

for epoch in range(EPOCHS):
    loss_sum = 0
    t1 = time.time()
    m.train()
    for i, batch in enumerate(dl_train):
        # zero the gradient
        optim.zero_grad()
        # put the batch on GPU
        batch = batch.cuda()
        # forward pass (autoencoder: the input is also the target)
        out = m(batch)
        # calculate loss
        loss = criterion(out, batch)
        # backward pass
        loss.backward()
        # update the weights
        optim.step()
        # add loss to loss sum
        loss_sum += loss.item()

    m.eval()
    with torch.no_grad():
        # validation and metric logging
        train_loss = loss_sum / len(dl_train)
        val_loss = 0

        # calculate validation loss
        for batch in dl_val:
            batch = batch.cuda()
            # forward pass
            out = m(batch)
            # calculate loss
            loss = criterion(out, batch)
            # add loss to loss sum
            val_loss += loss.item()

        # average the validation loss
        val_loss /= len(dl_val)
        # create validation images
        val_out = m(val_set.cuda())

        for col in range(len(val_set)):
            recon = torch.clamp(visualize(val_out[col]), 0, 1)
            grid[:, img_h:, col * img_w:(col + 1) * img_w] = recon.cpu()

        writer.add_image('images', grid, epoch)
        writer.add_scalar('Loss/train', train_loss, epoch)
        # Logged next to 'Loss/train' (the tag used to be 'Val/train').
        writer.add_scalar('Loss/val', val_loss, epoch)

        if val_loss < best_loss:
            best_loss = val_loss
            patience = 0
            # Snapshot on the CPU so the copy does not consume GPU memory.
            best_state = {name: tensor.detach().cpu().clone()
                          for name, tensor in m.state_dict().items()}
        else:
            patience += 1

        # Print before the stop check so the last epoch is reported as well.
        print(f"[{epoch + 1}/{EPOCHS}]\t({time.time() - t1:.2f}s)\tloss: {train_loss:.4f}\tval_loss: {val_loss:.4f}\tpatience: {patience}")

        if patience == PATIENCE:
            print("No new best model achieved, stopping here.")
            break

# Leave `m` holding the best weights seen on the validation set rather than
# those of the last (worse) epochs.
if best_state is not None:
    m.load_state_dict(best_state)

# Flush pending TensorBoard events to disk.
writer.close()

RuntimeError: CUDA out of memory. Tried to allocate 128.00 MiB (GPU 0; 11.00 GiB total capacity; 10.15 GiB already allocated; 0 bytes free; 10.21 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

In [None]:
# Reconstruct the first training image and display it.
# Inference only: without no_grad() the forward pass would record an autograd
# graph on the GPU just to be detached and thrown away.
with torch.no_grad():
    recon = m(img_ds[0].unsqueeze(0).cuda()).squeeze(0).cpu()

# De-normalise, clip to the displayable [0, 1] range, and move channels last.
plt.imshow(torch.clamp(visualize(recon), 0, 1).permute(1, 2, 0))