In [1]:
import os
import numpy as np
import matplotlib.pyplot as plt
import PIL 
from PIL import Image

import torch
from torch import nn
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader

In [2]:
target_path = "/kaggle/input/d/julinmaloof/the-oxfordiiit-pet-dataset/annotations/trimaps" 
input_path = "/kaggle/input/d/julinmaloof/the-oxfordiiit-pet-dataset/images"


target_annots = sorted([
    os.path.join(target_path, file) for file in os.listdir(target_path) 
    if file.endswith(".png") and not file.startswith(".")
])

input_imgs = sorted([
    os.path.join(input_path, file) for file in os.listdir(input_path) 
    if file.endswith('.jpg') and not file.startswith('.')
])

print(len(target_annots)) 
print(len(input_imgs)) 



7390
7390


In [3]:

transform = transforms.Compose([
    transforms.Resize((160,160)),
    transforms.ToTensor()
])


def get_prepared_data(img_paths, annot_paths):
    imgs = [] 
    annots = [] 
    for i, (img_path, annot_path ) in enumerate(zip(img_paths, annot_paths)):
        img = Image.open(img_path) 
        img = transform(img) 
        
        if img.shape[0] != 3:
            continue 
        
        annot = Image.open(annot_path) 
        annot = transform(annot) 
        
        if annot.shape[0] != 1:
            continue 
            
        imgs.append(img) 
        annots.append(annot) 
        
    return torch.stack(imgs, dim=0), torch.stack(annots, dim=0)
        




In [4]:
train_size = int(0.95 * len(input_imgs))
train_imgs_paths = input_imgs[:train_size] 
train_targets_paths = target_annots[:train_size] 

train_imgs, train_targets = get_prepared_data(train_imgs_paths, train_targets_paths)

test_imgs_path = input_imgs[train_size:]
test_targets_path = target_annots[train_size:] 

test_imgs, test_targets = get_prepared_data(test_imgs_path, test_targets_path)
    

In [5]:
class CustomDataset(Dataset):
    def __init__(self, imgs, annots):
        self.imgs = imgs
        self.annots = annots 
        
    def __len__(self):
        return len(self.imgs) 
    
    def __getitem__(self, index):
        return self.imgs[index], self.annots[index] 
    
    

train_ds = CustomDataset(train_imgs, train_targets) 
test_ds = CustomDataset(test_imgs, test_targets) 

train_dl = DataLoader(train_ds, batch_size=32, shuffle=True) 
test_dl = DataLoader(test_ds, batch_size=32) 
    

In [6]:
print(len(train_dl.dataset))
print(len(test_dl.dataset))

7008
370


In [7]:
annot = Image.open(train_targets_paths[0])
trns = transforms.Compose([
    transforms.ToTensor()
])
annot = torch.round(trns(annot) /0.0039)

torch.unique(annot)

tensor([1., 2., 3.])

In [8]:
class DownBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(DownBlock, self).__init__()  
        self.relu_1 = nn.ReLU() 
        self.conv_1 = nn.Conv2d(in_channels, out_channels, kernel_size = 3, padding=1) 
        self.bn_1 = nn.BatchNorm2d(out_channels) 
        
        self.relu_2 = nn.ReLU() 
        self.conv_2 = nn.Conv2d(out_channels, out_channels, kernel_size = 3, padding=1) 
        self.bn_2 = nn.BatchNorm2d(out_channels) 
        
        self.pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 
        
        self.resid = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=2, padding=0) 
        
    def forward(self, x, prev_block):
        x = self.relu_1(x) 
        x = self.conv_1(x) 
        x = self.bn_1(x) 
        x = self.relu_2(x) 
        x = self.conv_2(x) 
        x = self.bn_2(x) 
        x = self.pool(x)
        
        resid = self.resid(prev_block) 
        x = torch.add(x, resid) 
        return x

In [9]:
x = torch.randn(1,3,160,160) 
prev_block = x 
down = DownBlock(3, 64) 
y = down(x ,prev_block) 
print(y.shape)

torch.Size([1, 64, 80, 80])


In [10]:
class UpBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(UpBlock, self).__init__() 
        self.relu_1 = nn.ReLU() 
        self.conv_1 = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=3, padding=1) 
        self.bn_1 = nn.BatchNorm2d(out_channels) 
        
        self.relu_2 = nn.ReLU() 
        self.conv_2 = nn.ConvTranspose2d(out_channels, out_channels, kernel_size=3, padding=1) 
        self.bn_2 = nn.BatchNorm2d(out_channels)
        
        self.up_sample_1 = nn.Upsample(scale_factor=2) 
        self.up_sample_2 = nn.Upsample(scale_factor=2) 
        
        self.conv_3 = nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0) 
        
    def forward(self, x, prev_block):
        x = self.relu_1(x) 
        x = self.conv_1(x) 
        x = self.bn_1(x) 
        
        x = self.relu_2(x) 
        x = self.conv_2(x) 
        x = self.bn_2(x) 
        
        x = self.up_sample_1(x) 
       
        
        resid = self.up_sample_2(prev_block) 
        resid = self.conv_3(resid) 
   
        x = torch.add(x, resid) 
        return x

In [11]:
x = torch.randn(1,3,160,160) 
prev_block = x 

in_channels = 3
out_channels = 32 
upblock = UpBlock(in_channels, out_channels) 

z = upblock(x, prev_block) 
print(z.shape)

torch.Size([1, 32, 320, 320])


In [12]:
class CustomModel(nn.Module):
    def __init__(self, num_classes=3):
        super(CustomModel, self).__init__() 
        self.conv = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, stride=2, padding=1) 
        self.bn = nn.BatchNorm2d(32) 
        self.relu = nn.ReLU() 
        
        
        self.outputs = nn.Conv2d(32, num_classes, kernel_size=3,padding=1) 
        self.softmax = nn.Softmax(dim=1) 
        
    def forward(self, x):
        x = self.conv(x) 
        x = self.bn(x) 
        x = self.relu(x) 

        prev_block = x
        in_channels = 32
        
        for filters in [64,128,256]:
            x = DownBlock(in_channels, filters)(x, prev_block) 
            prev_block = x 
            in_channels = filters 
            
        for filters in [256,128,64,32]:
            x = UpBlock(in_channels, filters)(x, prev_block) 
            prev_block = x 
            in_channels = filters 
            
        outputs = self.outputs(x)
        outputs = self.softmax(outputs)
        return outputs
      

x = torch.randn(1,3,160,160) 
y = CustomModel()(x) 
print(y.shape)

torch.Size([1, 3, 160, 160])


In [28]:
model = CustomModel()  

device = torch.device('cuda' if torch.cuda.is_available() else "cpu") 
model = model.to(device)
print(device)

cuda


In [14]:
optim = torch.optim.Adam(params = model.parameters(), lr=1e-3)
loss_fn = nn.CrossEntropyLoss()

In [26]:
batch_size = 32

def train_model(model, data, optim, loss_fn, epochs):
    size = len(data.dataset) 
    all_losses = [] 
    all_accuracies = [] 
    
    model.train() 
    for epoch in range(epochs):
        losses = [] 
        accuracies = [] 
        print(f'\n\nEpoch : {epoch+1} ------------------------------------')
        for batch, (X,y) in enumerate(data):
            X = X.to(device) 
            y = torch.round(y /0.0039)
            y = y.to(torch.long) -1 
            y = y.view(-1, 160,160)
            y = y.to(device)
            preds = model(X) 
            loss = loss_fn(preds, y) 
            loss.backward() 
            
            optim.step() 
            optim.zero_grad()  
            
            loss = loss.item() 
            losses.append(loss) 
            points = y.shape[0]*y.shape[1]*y.shape[2] 
            accu = (preds.argmax(1) == y).type(torch.float).sum().item() / points
            accuracies.append(accu)  
            current = batch * batch_size + len(y) 
            if batch % 50 ==0:
                print(f"Loss : {loss:.3f} | Accuracy : {accu:.3f} Current : [{current :>3d}/{size}]")
            
        all_losses.append(losses) 
        all_accuracies.append(accuracies) 
        
    
    return all_losses, all_accuracies

    

In [27]:
losses, accuracies = train_model(model, train_dl, optim, loss_fn, 1)



Epoch : 1 ------------------------------------


RuntimeError: Input type (torch.cuda.FloatTensor) and weight type (torch.FloatTensor) should be the same

In [None]:
x = 2.5234315
print(f"{x:.3f}")

In [None]:
logits = torch.randn(32,3,160,160) 
target = torch.randint(0,3,(32,1,160,160), dtype=torch.long) 

target = target.view(-1,160,160)
loss = loss_fn(logits, target) 

print(logits.shape) 
print(target.shape)
print(loss)