In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Print the path of every file found under the read-only Kaggle input directory.
for root, _subdirs, files in os.walk('/kaggle/input'):
    for file_name in files:
        print(os.path.join(root, file_name))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [1]:
import numpy as np # linear algebra
import pandas as pd
data_dir = [f'../input/lyft-udacity-challenge/data{subset}/data{subset}' for subset in 'ABCDE']

# 5 dataset subsets (dataA ... dataE)

In [2]:
data_dir  # display the five subset directories built above (dataA ... dataE)

In [3]:
import numpy as np
import os
from torch.utils.data import Dataset
import torch
from PIL import Image
import matplotlib.pyplot as plt
from albumentations.pytorch import ToTensorV2
import albumentations as A
import torch.nn as nn
from torch.optim import Adam
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
import random
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.cuda import amp

In [4]:
!pip install -q segmentation_models_pytorch
!pip install -q scikit-learn==1.0

In [5]:
class CFG:
    """Notebook-wide hyper-parameters and runtime settings (read as class attributes)."""
    epochs = 10            # number of training epochs; the training loop reads this field
    lr = 1e-4              # Adam learning rate
    min_lr = 1e-6          # NOTE(review): currently unused (perhaps meant for an LR scheduler)
    num_classes = 13       # the UNET is built with 13 output channels; keep the two in sync
    seed = 101             # global RNG seed
    batch_size = 16
    num_epochs = 10        # NOTE(review): duplicates `epochs`; the training loop reads `epochs`
    num_workers = 2        # DataLoader worker processes
    shuffle = True         # DataLoader shuffle flag (see prepare_loaders)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [6]:
str(CFG.seed)  # show the configured seed

In [7]:
def set_seed(seed=42):
    '''Seed every RNG the notebook relies on so repeated runs give identical results.

    Seeds Python's `random`, NumPy's global generator, torch's CPU generator and the
    current CUDA generator. It also forces deterministic cuDNN kernels with autotuning
    disabled, and exports PYTHONHASHSEED. Prints a confirmation line when done.
    '''
    for seeder in (np.random.seed, random.seed, torch.manual_seed, torch.cuda.manual_seed):
        seeder(seed)
    # cuDNN may otherwise pick nondeterministic or autotuned convolution algorithms.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
    # Fixed hash seed for any subprocesses started after this point.
    os.environ['PYTHONHASHSEED'] = str(seed)
    print('> SEEDING DONE')
    
set_seed(CFG.seed)  # use the configured seed instead of set_seed's built-in default

In [9]:
class ImageDataset(Dataset):
    """Paired RGB image / segmentation-mask dataset.

    Each entry of ``data_dir`` must contain ``CameraRGB/`` and ``CameraSeg/`` sub-folders.
    Images and masks are paired by position after sorting each folder's file names. This
    assumes a mask has the same file name as its image (TODO confirm for new data).

    Args:
        data_dir: iterable of dataset root directories (strings, no trailing slash).
        transform: optional callable invoked as ``transform(image=..., mask=...)`` that
            returns a dict with ``'image'`` and ``'mask'`` (albumentations-style).

    Raises:
        ValueError: if the number of images and masks found differs.
    """

    def __init__(self, data_dir, transform=None):
        image_paths = [i + '/CameraRGB/' for i in data_dir]
        mask_paths = [i + '/CameraSeg/' for i in data_dir]

        self.images, self.masks = [], []
        self.transform = transform

        # os.listdir returns entries in arbitrary, filesystem-dependent order. Sorting both
        # listings keeps images[k] and masks[k] referring to the same frame.
        for folder in image_paths:
            self.images.extend(folder + name for name in sorted(os.listdir(folder)))
        for folder in mask_paths:
            self.masks.extend(folder + name for name in sorted(os.listdir(folder)))

        if len(self.images) != len(self.masks):
            raise ValueError(
                f"found {len(self.images)} images but {len(self.masks)} masks; "
                "every image needs a mask with the same file name"
            )

    def __len__(self):
        """Number of image/mask pairs."""
        return len(self.images)

    def __getitem__(self, index):
        """Load pair ``index`` and return ``(image, mask)``.

        Without a transform, both are returned as raw numpy arrays. With a transform, the
        mask's last axis is collapsed with a max. After ToTensorV2 the mask presumably stays
        HWC, so this yields an (H, W) label map (TODO confirm the label channel layout).
        """
        img_path = self.images[index]
        mask_path = self.masks[index]

        img = np.array(Image.open(img_path))
        mask = np.array(Image.open(mask_path))

        if self.transform is not None:
            augmentations = self.transform(image=img, mask=mask)
            img = augmentations['image']
            mask = augmentations['mask']
            mask = torch.max(mask, dim=2)[0]

        return img, mask
        
        

In [8]:
t1 = A.Compose(
    [
        A.Resize(256, 256),
        # Maps uint8 pixels to roughly [-1, 1]: (x - 0.5 * 255) / (0.5 * 255).
        A.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
        ToTensorV2(),  # numpy HWC image -> torch CHW tensor
    ]
)

In [10]:
df = ImageDataset(data_dir,transform = t1)  # NOTE(review): a torch Dataset, not a DataFrame; 'dataset' would be a clearer name

In [11]:
img, msk = df[0]  # first (image, mask) pair, fetched via the Dataset's subscript protocol

In [12]:
img.shape  # expected (3, 256, 256): t1 resizes to 256x256, then ToTensorV2 puts channels first

In [13]:
msk.shape  # expected (256, 256): ImageDataset collapses the mask's last axis with torch.max

In [14]:
from PIL import Image as im
# t1 normalises pixels to about [-1, 1] (mean = std = 0.5). Undo that so imshow shows the real
# colours instead of clipping every negative value to black.
plt.imshow((img * 0.5 + 0.5).clamp(0, 1).permute(1, 2, 0).numpy())
plt.show()

In [15]:
from PIL import Image as im
# msk is already a 2-D label map (ImageDataset collapses the channel axis), so no reshape to a
# hard-coded 256x256 is needed. This also keeps working if t1's Resize size changes.
plt.imshow(np.asarray(msk))
plt.show()

In [16]:
import torch
import torch.nn as nn
import torchvision.transforms.functional as TF

class DoubleConv(nn.Module):
    """Two (3x3 conv -> BatchNorm -> ReLU) stages; padding=1 keeps the spatial size.

    All layers live in one ``nn.Sequential`` named ``conv`` (indices 0-5), so state_dict
    keys stay ``conv.0.weight``, ``conv.1.*``, ... as in existing checkpoints.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = []
        for stage_in in (in_channels, out_channels):
            stages.extend([
                # bias is redundant: the BatchNorm right after it adds its own shift
                nn.Conv2d(stage_in, out_channels, kernel_size=3, stride=1, padding=1, bias=False),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ])
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        return self.conv(x)

class UNET(nn.Module):
    """U-Net for dense per-pixel prediction.

    Encoder: one DoubleConv per entry of ``features``, each followed by 2x2 max-pooling.
    Bottleneck: DoubleConv doubling the deepest width.
    Decoder (deepest level first): a 2x2 transposed conv doubles H/W. The matching encoder
    activation is then concatenated on the channel axis, and a DoubleConv fuses the two.

    Args:
        in_channels: channels of the input image.
        out_channels: channels of the output map (e.g. number of classes).
        features: encoder widths, shallowest first. A tuple avoids a shared mutable
            default; any sequence (e.g. a list) is still accepted.

    ``forward(x)`` returns ``out_channels`` logits with the same H/W as ``x``.
    """

    def __init__(self, in_channels=3, out_channels=1, features=(64, 128, 256, 512)):
        super(UNET, self).__init__()
        self.downs = nn.ModuleList()
        self.ups = nn.ModuleList()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        for feature in features:
            self.downs.append(DoubleConv(in_channels, feature))
            in_channels = feature

        # ups alternates [ConvTranspose2d, DoubleConv] per level; forward() relies on that.
        for feature in reversed(features):
            self.ups.append(
                nn.ConvTranspose2d(feature * 2, feature, kernel_size=2, stride=2)
            )
            self.ups.append(DoubleConv(feature * 2, feature))

        self.bottleneck = DoubleConv(features[-1], features[-1] * 2)
        self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)

    def forward(self, x):
        skip_connections = []

        for down in self.downs:
            x = down(x)
            skip_connections.append(x)
            x = self.pool(x)

        x = self.bottleneck(x)
        skip_connections.reverse()

        for idx in range(0, len(self.ups), 2):
            x = self.ups[idx](x)
            skip_connection = skip_connections[idx // 2]
            # Pooling floors odd sizes, so the upsampled map can be a pixel smaller than its
            # skip; resize it to the skip's (H, W). Must use .shape[2:]: the tensor slice
            # skip_connection[2:] would index the batch dimension instead.
            if x.shape != skip_connection.shape:
                x = TF.resize(x, size=list(skip_connection.shape[2:]))
            concat_skip = torch.cat((skip_connection, x), dim=1)

            x = self.ups[idx + 1](concat_skip)

        return self.final_conv(x)

def test():
    """Smoke test: a (3, 3, 256, 256) batch must come out of UNET as (3, 1, 256, 256)."""
    batch = torch.randn((3, 3, 256, 256))
    reference = torch.randn((3, 1, 256, 256))
    net = UNET(in_channels=3, out_channels=1)
    output = net(batch)
    print(output.shape)
    print(reference.shape)
    assert output.shape == reference.shape

In [None]:
test()  # smoke test: UNET output must keep the input's batch and spatial size

In [17]:
use_cuda = torch.cuda.is_available()
device = CFG.device  # single source of truth; train_fn and metric already read CFG.device
# set_seed() disabled cuDNN autotuning for reproducibility. Re-enabling it (as this cell
# originally did) lets the chosen convolution algorithms, and so the results, vary between runs.
# Flip to True only when speed matters more than run-to-run reproducibility.
torch.backends.cudnn.benchmark = False

In [18]:
def prepare_loaders(data_dir):
    """Build train / validation / test DataLoaders from the given dataset directories.

    Split: 10% test. The remaining 90% is split 20% validation / 80% train, giving about
    72 / 18 / 10 overall (sizes are floored). The split uses its own generator seeded with
    CFG.seed, so it does not depend on how much global RNG earlier cells consumed.

    Only the training loader shuffles (per CFG.shuffle). The validation and test loaders
    keep a fixed order so metrics and example plots are reproducible.

    Returns:
        (train_loader, valid_loader, test_loader)
    """
    data = ImageDataset(data_dir, transform=t1)
    n_total = len(data)
    train_size = int(0.9 * n_total)
    valid_size = int(0.2 * train_size)
    test_size = n_total - train_size
    train_size -= valid_size

    split_generator = torch.Generator().manual_seed(CFG.seed)
    train_dataset, valid_dataset, test_dataset = torch.utils.data.random_split(
        data, [train_size, valid_size, test_size], generator=split_generator
    )

    loader_kwargs = dict(
        batch_size=CFG.batch_size,
        num_workers=CFG.num_workers,
        pin_memory=torch.cuda.is_available(),  # pinning only helps host->GPU copies
    )
    train_loader = DataLoader(train_dataset, shuffle=CFG.shuffle, **loader_kwargs)
    valid_loader = DataLoader(valid_dataset, shuffle=False, **loader_kwargs)
    test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

    return train_loader, valid_loader, test_loader

In [19]:
pip install torchsummary 

In [None]:
def train_fn(loader, model, optimizer, loss_fn, scaler):
    """Run one training epoch with automatic mixed precision.

    Args:
        loader: DataLoader yielding (images, masks) batches.
        model: network producing per-pixel class logits.
        optimizer: optimizer over ``model``'s parameters.
        loss_fn: loss taking (logits, int64 targets), e.g. CrossEntropyLoss.
        scaler: GradScaler used to scale the loss for mixed-precision backward.

    Returns:
        float: mean training loss over the epoch's batches (nan for an empty loader).
    """
    # metric() leaves the model in eval mode; without this, BatchNorm would keep using frozen
    # running statistics when training is resumed after an evaluation.
    model.train()
    use_amp = CFG.device.type == "cuda"  # autocast via torch.cuda.amp only applies on CUDA
    loop = tqdm(loader, total=len(loader))
    total_loss, n_batches = 0.0, 0

    for data, targets in loop:
        data = data.to(device=CFG.device)
        targets = targets.to(CFG.device).long()  # CrossEntropyLoss expects int64 class indices

        with torch.cuda.amp.autocast(enabled=use_amp):
            predictions = model(data)
            loss = loss_fn(predictions, targets)

        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        loss_value = loss.item()
        total_loss += loss_value
        n_batches += 1
        loop.set_postfix(loss=loss_value)

    return total_loss / n_batches if n_batches else float("nan")

In [22]:
# UNET with 13 output channels, one logit map per class label.
model = UNET(in_channels=3, out_channels=13).to(CFG.device)
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=CFG.lr)
train_loader, val_loader, test_loader = prepare_loaders(data_dir)
scaler = torch.cuda.amp.GradScaler()

for epoch in range(CFG.epochs):
    train_fn(train_loader, model, optimizer, loss_fn, scaler)
    # In-memory snapshot of this epoch's model and optimizer state (not written to disk).
    checkpoint = {
        "state_dict": model.state_dict(),
        "optimizer": optimizer.state_dict(),
    }
    # TODO: persist `checkpoint` and evaluate on `val_loader` each epoch. The helpers an
    # earlier draft called here (save_checkpoint, check_accuracy, save_predictions_as_imgs
    # with folder="save_images/") are not defined in this notebook.


In [23]:
def metric(loader, model, device=None):
    """Print and return pixel accuracy and mean per-class Dice of ``model`` on ``loader``.

    Dice is computed per class from intersection and cardinality counts accumulated over the
    whole loader. It is then averaged over the classes that appear in either the predictions
    or the masks. The binary formula on raw label values is meaningless for multi-class masks.
    The model is switched to eval mode and left there.

    Args:
        loader: iterable of (images, masks) batches. Masks hold integer class labels shaped
            like the argmax of the model's logits over dim 1.
        model: network returning (B, C, H, W) logits, where C is the number of classes.
        device: device to run inference on; defaults to CFG.device.

    Returns:
        (accuracy_percent, mean_dice) as Python floats.

    Raises:
        ValueError: if the loader yields no pixels.
    """
    if device is None:
        device = CFG.device
    model.eval()

    num_correct = 0
    num_pixels = 0
    intersection = None  # per-class |pred AND mask| counts, sized once C is known
    cardinality = None   # per-class |pred| + |mask| counts

    with torch.no_grad():
        for x, y in loader:
            x = x.to(device)
            y = y.to(device).long()
            logits = model(x)
            # softmax is monotonic, so the argmax of the raw logits gives the same labels
            preds = logits.argmax(dim=1)
            num_correct += (preds == y).sum().item()
            num_pixels += preds.numel()

            classes = torch.arange(logits.shape[1], device=device).view(-1, 1, 1, 1)
            pred_hot = preds.unsqueeze(0) == classes  # (C, B, H, W) boolean one-hot
            mask_hot = y.unsqueeze(0) == classes
            batch_inter = (pred_hot & mask_hot).flatten(1).sum(dim=1)
            batch_card = pred_hot.flatten(1).sum(dim=1) + mask_hot.flatten(1).sum(dim=1)
            if intersection is None:
                intersection, cardinality = batch_inter, batch_card
            else:
                intersection = intersection + batch_inter
                cardinality = cardinality + batch_card

    if num_pixels == 0:
        raise ValueError("metric() received a loader that yielded no pixels")

    present = cardinality > 0
    dice = (2 * intersection[present].float() / cardinality[present].float()).mean().item()
    accuracy = num_correct / num_pixels * 100

    print(f"Got {num_correct}/{num_pixels} with acc {accuracy:.2f}")
    print(f"Dice score: {dice}")
    return accuracy, dice

In [24]:
metric(test_loader, model)  # pixel accuracy and Dice on the held-out test split

In [None]:
torch.save(model.state_dict(),"checpoint.pth")
PATH = "checpoint.pth"

In [None]:
# Rebuild the architecture used for training (TheModelClass was a placeholder that raised
# NameError), then load the saved weights. map_location lets a GPU-saved file load on CPU.
model = UNET(in_channels=3, out_channels=13).to(CFG.device)
model.load_state_dict(torch.load(PATH, map_location=CFG.device))


In [26]:
def _show_prediction_grid(images, preds, masks, indices):
    """Plot one row (image | prediction | mask) per batch index in ``indices``.

    images: normalised (B, 3, H, W) CPU tensor; preds / masks: (B, H, W) CPU label tensors.
    Returns the created matplotlib Figure.
    """
    fig, ax = plt.subplots(len(indices), 3, figsize=(18, 6 * len(indices)), squeeze=False)
    for row, i in enumerate(indices):
        # Undo t1's Normalize(mean=0.5, std=0.5) so the RGB frame displays with true colours.
        rgb = (images[i] * 0.5 + 0.5).clamp(0, 1).permute(1, 2, 0).numpy()
        panels = (("Image", rgb), ("Prediction", preds[i].numpy()), ("Mask", masks[i].numpy()))
        for col, (title, data) in enumerate(panels):
            ax[row][col].set_title(title)
            ax[row][col].axis("off")
            ax[row][col].imshow(data)
    return fig


# Run inference once on the first test batch (no gradients, eval-mode BatchNorm) and show
# two 3-row grids, skipping any index beyond the batch size.
model.eval()
x, y = next(iter(test_loader))
with torch.no_grad():
    preds = model(x.to(CFG.device)).argmax(dim=1).cpu()
for indices in ((0, 1, 2), (10, 11, 12)):
    in_batch = [i for i in indices if i < x.shape[0]]
    if in_batch:
        _show_prediction_grid(x, preds, y, in_batch)