In [None]:
import torch
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import ToTensor, Resize
import torch.nn as N
import matplotlib.pyplot as plt

In [None]:
# Reproducibility: seed torch's global RNG before any model or DataLoader is built.
torch.manual_seed(42)

# Optimisation hyperparameters.
batch_size, epochs, learning_rate = 32, 5, 1e-3

# Encoder widths, one entry per U-Net level (shallowest first).
filters = [32, 64, 128, 256]
# Presumably a dropout probability; not referenced by the cells shown here.
drop_rate = 0.15

# Presumably the number of input image channels (3 = RGB).
img_depth = 3

In [None]:
class OxfordPetDataset(Dataset):
    """Oxford-IIIT Pet images paired with their trimap masks.

    Sample names come from ``<root>/annotations/<split>.txt`` (first token of
    each non-blank line). For a name ``n`` the input is ``<root>/images/n.jpg``
    and the target is ``<root>/annotations/trimaps/n.png``.

    Args:
        split: ``'trainval'`` or ``'test'``; selects which file list is read.
        transforms: callables applied in order to the input image (and to the
            target too, unless ``target_transforms`` is given). ``None`` means
            no transforms.
        root: dataset root directory. The default is the path previously
            hard-coded, so existing callers are unaffected.
        target_transforms: optional separate pipeline for the target. ``None``
            (default) reuses ``transforms``, matching the original behaviour.

    Raises:
        ValueError: if ``split`` is neither ``'trainval'`` nor ``'test'``
            (previously any other value silently loaded the test list).
    """

    _SPLITS = ('trainval', 'test')

    def __init__(self, split= 'trainval', transforms= None,
                 root= '../pytorch-basics/datasets/oxford-iiit-pet',
                 target_transforms= None) -> None:
        if split not in self._SPLITS:
            raise ValueError(f"split must be one of {self._SPLITS}, got {split!r}")

        # Copy the caller's list. A fresh list per instance also replaces the
        # shared mutable default ``[]`` the original signature used.
        self.transforms = list(transforms) if transforms is not None else []
        # NOTE(review): image transforms such as Resize (bilinear by default)
        # blend label values when applied to a mask. For segmentation, pass a
        # mask-appropriate pipeline here.
        self.target_transforms = (
            list(target_transforms) if target_transforms is not None else self.transforms
        )

        root = root.rstrip('/')
        self.input_dir = root + '/images/'
        self.target_dir = root + '/annotations/trimaps/'
        file_list_path = f'{root}/annotations/{split}.txt'

        with open(file_list_path, 'r') as f:
            # split() (no argument) also strips the newline when a line has a
            # single token; blank lines are skipped instead of becoming '' names.
            self.list_of_files = [line.split()[0] for line in f if line.strip()]

    def __len__(self):
        """Number of samples listed in the split file."""
        return len(self.list_of_files)

    def __getitem__(self, index):
        """Load sample ``index`` and return ``(image, target)`` after transforms."""
        name = self.list_of_files[index]
        image = self._ensure_rgb(plt.imread(self.input_dir + name + '.jpg', format= 'jpg'))
        target = plt.imread(self.target_dir + name + '.png', format= 'png')

        for transform in self.transforms:
            image = transform(image)
        for transform in self.target_transforms:
            target = transform(target)

        return (image, target)

    @staticmethod
    def _ensure_rgb(image):
        """Return an H x W x 3 view/copy of a decoded image array.

        Images that decode as grayscale (H x W) or with a fourth channel
        (e.g. RGBA) would otherwise become tensors whose channel count differs
        from 3, and could not be stacked into a batch with RGB samples.
        """
        if image.ndim == 2:
            return image[..., None].repeat(3, axis= -1)
        if image.shape[-1] == 4:
            return image[..., :3]
        return image

In [None]:
# Build both splits with an identical (but separate) pipeline: array -> tensor, then resize to 256x256.
# NOTE(review): the same pipeline is also applied to the trimap target, and Resize interpolates bilinearly.
training_data, test_data = (
    OxfordPetDataset(split= split_name, transforms= [ToTensor(), Resize((256, 256))])
    for split_name in ('trainval', 'test')
)

In [None]:
# Training: reshuffle every epoch and drop the ragged final batch so every step sees a full batch.
train_dataloader = DataLoader(training_data, batch_size= batch_size, shuffle= True, drop_last= True)
# Evaluation must see every test sample exactly once. drop_last=True silently discarded up to
# batch_size - 1 test images, and shuffling the test set serves no purpose.
test_dataloader = DataLoader(test_data, batch_size= batch_size, shuffle= False, drop_last= False)

In [None]:
class DownwardBlock(N.Module):
    """Two stacked Conv2d -> BatchNorm2d -> ReLU stages (U-Net "double conv").

    The first convolution maps ``in_channels`` to ``out_channels``; the second
    keeps ``out_channels``. Convolutions have no bias (BatchNorm follows them)
    and share ``kernel_size``, ``stride`` and ``padding``. All parameters are
    created on ``device``.
    """

    def __init__(self, in_channels, out_channels, device, stride= 1, kernel_size= 3, padding= 'same'):
        super().__init__()

        layers = []
        # Same construction order as two hand-written stages, so Sequential
        # indices (0..5) and parameter initialisation order are unchanged.
        for conv_in in (in_channels, out_channels):
            layers.extend([
                N.Conv2d(conv_in, out_channels, kernel_size, stride, padding, device= device, bias= False),
                N.BatchNorm2d(out_channels, device= device),
                N.ReLU(inplace= True),
            ])
        self.enc_block = N.Sequential(*layers)

    def forward(self, input):
        """Apply both conv stages to ``input``."""
        return self.enc_block(input)

In [None]:
class DownwardHalf(N.Module):
    """U-Net encoder: a chain of DownwardBlocks, each followed by 2x2 max-pooling.

    Args:
        filters: output channel count of each encoder level, shallowest first.
        device: device on which block parameters are created.
        in_channels: channel count of the network input. Defaults to 3 (the
            value previously hard-coded), so existing callers are unaffected;
            pass e.g. the notebook's ``img_depth`` for non-RGB input.
    """

    def __init__(self, filters, device, in_channels= 3):
        super(DownwardHalf, self).__init__()
        # Level i maps the previous level's width (the input width for i == 0) to filters[i].
        level_inputs = [in_channels, *filters[:-1]]
        self.down_blocks = N.ModuleList(
            DownwardBlock(c_in, c_out, device= device)
            for c_in, c_out in zip(level_inputs, filters)
        )
        self.mpool = N.MaxPool2d(kernel_size= 2, stride= 2)

    def forward(self, input):
        """Run every encoder level on ``input``.

        Returns:
            ``(output, skip_conn)``: ``output`` is the pooled result of the last
            level. ``skip_conn`` lists each level's activation *before* pooling,
            shallowest first (empty when ``filters`` is empty).
        """
        skip_conn = []
        for block in self.down_blocks:
            input = block(input)
            skip_conn.append(input)
            input = self.mpool(input)

        return input, skip_conn

In [None]:
class UpwardBlock(N.Module):
    """U-Net decoder stage: 2x up-convolution, concat with a skip tensor, then two conv-BN-ReLU stages.

    ``tconv`` maps ``in_channels`` to ``in_channels // 2`` channels while
    doubling height and width. Its output is concatenated (along dim 1) with the
    skip tensor and fed to ``dec_block``, whose first convolution expects
    ``in_channels`` channels. The skip tensor must therefore carry
    ``in_channels - in_channels // 2`` channels, i.e. half of ``in_channels`` in
    the usual U-Net layout.

    Args:
        skip_conn: unused; kept so existing call sites remain valid.
            NOTE(review): the original never read this argument -- confirm intent.
        in_channels: channels of the decoder input.
        out_channels: channels produced by this stage.
        device: device on which parameters are created.
        stride, kernel_size, padding: forwarded to both Conv2d layers of
            ``dec_block``. The default ``padding`` was previously the invalid
            string ``' same'`` (leading space), which made construction raise.
    """

    def __init__(self, skip_conn, in_channels, out_channels, device, stride= 1, kernel_size= 3, padding= 'same'):
        super(UpwardBlock, self).__init__()

        # kernel 2 / stride 2 / no padding gives exactly 2H x 2W: (H - 1) * 2 + 2.
        # The former k=3, s=1 transposed conv only grew each side by 2 pixels and
        # kept in_channels, so the concat never matched dec_block's input width.
        self.tconv = N.ConvTranspose2d(in_channels, in_channels // 2, kernel_size= 2, stride= 2, device= device)
        self.dec_block = N.Sequential(
            N.Conv2d(in_channels, out_channels, kernel_size, stride, padding, device= device, bias= False),
            N.BatchNorm2d(out_channels, device= device),
            N.ReLU(inplace= True),

            N.Conv2d(out_channels, out_channels, kernel_size, stride, padding, device= device, bias= False),
            N.BatchNorm2d(out_channels, device= device),
            N.ReLU(inplace= True)
        )

    def forward(self, input, skip_conn):
        """Upsample ``input``, join it with ``skip_conn`` on channels, and decode."""
        tconv_output = self.tconv(input)
        # If the encoder's input size was not divisible by the pooling factor, the
        # upsampled map can be smaller than the skip tensor; match its H x W so
        # torch.cat along channels is valid.
        if tconv_output.shape[-2:] != skip_conn.shape[-2:]:
            tconv_output = N.functional.interpolate(tconv_output, size= skip_conn.shape[-2:])
        concat_output = torch.cat([tconv_output, skip_conn], dim= 1)
        return self.dec_block(concat_output)

In [None]:
class UpwardHalf(N.Module):
    def __init__(self):
        super(UpwardHalf, self).__init__()
        self.up_blocks = N.ModuleList()
