In [None]:
!pip install albumentations==0.4.6
import os
import numpy as np
import torch
import albumentations as A
from albumentations.pytorch import ToTensorV2
from tqdm import tqdm
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
import cv2
from torch import nn,optim
import torch
import torchvision.transforms.functional as TF
import torch.nn.functional as F
import matplotlib.pyplot as plt

class UNET(nn.Module):
    """U-Net for per-pixel classification (semantic segmentation).

    Encoder: five stages of (3x3 conv -> BatchNorm -> ReLU) x 2, with 64, 128,
    256, 512 and 1024 channels, separated by 2x2 max-pooling.
    Decoder: four stages that up-sample with a 2x2 stride-2 transposed conv,
    concatenate the matching encoder features (skip connection, encoder map
    first) and apply another (conv -> BatchNorm -> ReLU) x 2.
    A final 1x1 conv maps to ``num_classes`` channels.

    Args:
        in_channels: number of input image channels (default 3).
        num_classes: number of output classes (default 4).

    ``forward`` returns per-pixel log-probabilities of shape
    (N, num_classes, H, W) (log_softmax over dim 1), i.e. the input that
    ``nn.NLLLoss`` expects.
    """

    def __init__(self, in_channels=3, num_classes=4):
        super().__init__()

        # The attribute names below are the state_dict keys of saved
        # checkpoints (e.g. multi_unet_2.pth) -- do not rename them.

        # DOWN SAMPLE LAYERS
        self.conv1_down_1 = nn.Conv2d(in_channels, 64, 3, padding=1)
        self.bn1_down_1 = nn.BatchNorm2d(64)
        self.conv1_down_2 = nn.Conv2d(64, 64, 3, padding=1)
        self.bn1_down_2 = nn.BatchNorm2d(64)

        self.conv2_down_1 = nn.Conv2d(64, 128, 3, padding=1)
        self.bn2_down_1 = nn.BatchNorm2d(128)
        self.conv2_down_2 = nn.Conv2d(128, 128, 3, padding=1)
        self.bn2_down_2 = nn.BatchNorm2d(128)

        self.conv3_down_1 = nn.Conv2d(128, 256, 3, padding=1)
        self.bn3_down_1 = nn.BatchNorm2d(256)
        self.conv3_down_2 = nn.Conv2d(256, 256, 3, padding=1)
        self.bn3_down_2 = nn.BatchNorm2d(256)

        self.conv4_down_1 = nn.Conv2d(256, 512, 3, padding=1)
        self.bn4_down_1 = nn.BatchNorm2d(512)
        self.conv4_down_2 = nn.Conv2d(512, 512, 3, padding=1)
        self.bn4_down_2 = nn.BatchNorm2d(512)

        self.conv5_down_1 = nn.Conv2d(512, 1024, 3, padding=1)
        self.bn5_down_1 = nn.BatchNorm2d(1024)
        self.conv5_down_2 = nn.Conv2d(1024, 1024, 3, padding=1)
        self.bn5_down_2 = nn.BatchNorm2d(1024)

        self.pool = nn.MaxPool2d(2, 2)

        # UP SAMPLE LAYERS: each *_up_1 conv sees [skip, up-sampled] features,
        # hence twice the channels of the corresponding up_sample output.
        self.up_sample_1 = nn.ConvTranspose2d(1024, 512, 2, 2)
        self.conv1_up_1 = nn.Conv2d(1024, 512, 3, padding=1)
        self.bn1_up_1 = nn.BatchNorm2d(512)
        self.conv1_up_2 = nn.Conv2d(512, 512, 3, padding=1)
        self.bn1_up_2 = nn.BatchNorm2d(512)

        self.up_sample_2 = nn.ConvTranspose2d(512, 256, 2, 2)
        self.conv2_up_1 = nn.Conv2d(512, 256, 3, padding=1)
        self.bn2_up_1 = nn.BatchNorm2d(256)
        self.conv2_up_2 = nn.Conv2d(256, 256, 3, padding=1)
        self.bn2_up_2 = nn.BatchNorm2d(256)

        self.up_sample_3 = nn.ConvTranspose2d(256, 128, 2, 2)
        self.conv3_up_1 = nn.Conv2d(256, 128, 3, padding=1)
        self.bn3_up_1 = nn.BatchNorm2d(128)
        self.conv3_up_2 = nn.Conv2d(128, 128, 3, padding=1)
        self.bn3_up_2 = nn.BatchNorm2d(128)

        self.up_sample_4 = nn.ConvTranspose2d(128, 64, 2, 2)
        self.conv4_up_1 = nn.Conv2d(128, 64, 3, padding=1)
        self.bn4_up_1 = nn.BatchNorm2d(64)
        self.conv4_up_2 = nn.Conv2d(64, 64, 3, padding=1)
        self.bn4_up_2 = nn.BatchNorm2d(64)

        self.output = nn.Conv2d(64, num_classes, 1)

    @staticmethod
    def _double_conv(x, conv_a, bn_a, conv_b, bn_b):
        """Apply (conv -> BatchNorm -> ReLU) twice."""
        x = F.relu(bn_a(conv_a(x)))
        return F.relu(bn_b(conv_b(x)))

    def _up_block(self, x, skip, up, conv_a, bn_a, conv_b, bn_b):
        """Up-sample ``x``, concatenate it after ``skip`` and apply a double conv.

        MaxPool2d floors odd spatial sizes, so when H or W of the network input
        is not a multiple of 16 the up-sampled map can be one pixel smaller
        than ``skip``.  It is zero-padded (extra pixel on the bottom/right) so
        the concatenation along the channel axis is possible.
        """
        x = up(x)
        diff_h = skip.size(2) - x.size(2)
        diff_w = skip.size(3) - x.size(3)
        if diff_h or diff_w:
            x = F.pad(x, [diff_w // 2, diff_w - diff_w // 2,
                          diff_h // 2, diff_h - diff_h // 2])
        x = torch.cat((skip, x), dim=1)
        return self._double_conv(x, conv_a, bn_a, conv_b, bn_b)

    def forward(self, x):
        """Return per-pixel log-probabilities of shape (N, num_classes, H, W)."""
        # DOWN SAMPLING: cd1..cd4 are kept for the skip connections
        cd1 = self._double_conv(x, self.conv1_down_1, self.bn1_down_1,
                                self.conv1_down_2, self.bn1_down_2)
        cd2 = self._double_conv(self.pool(cd1), self.conv2_down_1, self.bn2_down_1,
                                self.conv2_down_2, self.bn2_down_2)
        cd3 = self._double_conv(self.pool(cd2), self.conv3_down_1, self.bn3_down_1,
                                self.conv3_down_2, self.bn3_down_2)
        cd4 = self._double_conv(self.pool(cd3), self.conv4_down_1, self.bn4_down_1,
                                self.conv4_down_2, self.bn4_down_2)
        cd5 = self._double_conv(self.pool(cd4), self.conv5_down_1, self.bn5_down_1,
                                self.conv5_down_2, self.bn5_down_2)

        # UP SAMPLING
        cu = self._up_block(cd5, cd4, self.up_sample_1, self.conv1_up_1, self.bn1_up_1,
                            self.conv1_up_2, self.bn1_up_2)
        cu = self._up_block(cu, cd3, self.up_sample_2, self.conv2_up_1, self.bn2_up_1,
                            self.conv2_up_2, self.bn2_up_2)
        cu = self._up_block(cu, cd2, self.up_sample_3, self.conv3_up_1, self.bn3_up_1,
                            self.conv3_up_2, self.bn3_up_2)
        cu = self._up_block(cu, cd1, self.up_sample_4, self.conv4_up_1, self.bn4_up_1,
                            self.conv4_up_2, self.bn4_up_2)

        return F.log_softmax(self.output(cu), dim=1)


class Dataset(torch.utils.data.Dataset):
    """Paired image / segmentation-mask dataset read from two directories.

    Both directory listings are sorted by file name and image ``i`` is paired
    with mask ``i``, so corresponding files must sort into the same order
    (e.g. identical file names in both folders).

    Each grey-level mask is quantised into class indices by intensity band
    (``MASK_BANDS``, inclusive bounds): 0: 0-50, 1: 70-120, 2: 140-200,
    3: 200-255.  A pixel of value 200 falls in two bands and gets class 2
    (argmax keeps the first maximum); pixels in no band get class 0.

    Args:
        image_dir: directory containing the input images (read with OpenCV,
            converted BGR -> RGB).
        mask_dir: directory containing the grey-level masks.
        transform: optional albumentations-style callable invoked as
            ``transform(image=..., mask=...)`` and returning a dict with
            ``"image"`` and ``"mask"`` entries.

    ``__getitem__`` returns ``(image, mask)``: the transform's outputs when a
    transform is given, otherwise the RGB numpy image and the integer
    class-index array.

    Raises:
        ValueError: (constructor) if the two directories hold a different
            number of files, which means the index pairing cannot be right.
        FileNotFoundError: (``__getitem__``) if OpenCV cannot read a file.
    """

    # (low, high) grey-level bands for classes 0..3, inclusive (cv2.inRange).
    MASK_BANDS = ((0, 50), (70, 120), (140, 200), (200, 255))

    # The base class is spelled out instead of `Dataset` so that re-running
    # this cell does not make the class inherit from its previous definition.

    def __init__(self, image_dir, mask_dir, transform=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        # os.listdir returns entries in arbitrary order; sort both listings so
        # index i refers to corresponding image and mask files.
        self.images = sorted(os.listdir(image_dir))
        self.masks = sorted(os.listdir(mask_dir))
        if len(self.images) != len(self.masks):
            raise ValueError(
                f"{len(self.images)} files in {image_dir!r} but "
                f"{len(self.masks)} files in {mask_dir!r}"
            )

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        image_path = os.path.join(self.image_dir, self.images[index])
        mask_path = os.path.join(self.mask_dir, self.masks[index])

        image = cv2.imread(image_path)
        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        # cv2.imread signals a missing/unreadable file by returning None.
        if image is None:
            raise FileNotFoundError(f"could not read image {image_path!r}")
        if mask is None:
            raise FileNotFoundError(f"could not read mask {mask_path!r}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # One 0/255 indicator map per class band; argmax over the class axis
        # turns them into a single class-index map.
        indicators = np.array([cv2.inRange(mask, low, high) for low, high in self.MASK_BANDS])
        target = np.argmax(indicators, axis=0)

        if self.transform is not None:
            augmented = self.transform(image=image, mask=target)
            image = augmented["image"]
            target = augmented["mask"]

        return image, target

def plotVals(
    losses,
    ylim=(0, 1.5),
    f_name="Losses",
    x_label="Training Step",
    y_label="Training Loss",
    out_root="/content/drive/MyDrive/Segmentation/MRI SEG/Plots/Multi",
):
    """Plot a metric curve and save it as a PNG file.

    The figure is written to ``<out_root>/<f_name>/<y_label>.png``.  The
    ``<out_root>/<f_name>`` directory is created if needed, since savefig
    fails on a missing directory.

    Args:
        losses: sequence of values; value ``i`` is plotted at x = i.
        ylim: (bottom, top) limits of the y axis.
        f_name: sub-folder of ``out_root`` that receives the figure.
        x_label: x-axis label.
        y_label: y-axis label; also used for the legend entry, the title
            (``"<y_label> Curve"``) and the file name.
        out_root: base output directory (defaults to the project's Drive
            plot folder).
    """
    out_dir = os.path.join(out_root, f_name)
    os.makedirs(out_dir, exist_ok=True)

    # A fresh figure per call, so nothing is drawn into a leftover figure.
    fig, ax = plt.subplots()
    ax.plot(np.arange(len(losses)), losses, label=y_label + " Curve")
    ax.set_xlabel(x_label)
    ax.set_ylabel(y_label)
    ax.set_ylim(ylim)
    ax.legend()
    ax.set_title(y_label + " Curve")
    fig.savefig(os.path.join(out_dir, y_label + ".png"), bbox_inches="tight")
    plt.close(fig)


# Metric histories filled by the training cell and plotted at the end:
# one training loss per optimisation step, one validation loss / accuracy per epoch.
training_losses = []
validation_losses = []
accuracies = []

# Seed for the train/validation split.  With an unseeded shuffle every run
# draws a different split, so a run resumed from a saved checkpoint would
# validate on images the model had already been trained on.
SPLIT_SEED = 42
VAL_FRACTION = 0.05

TRAIN_IMAGE_DIR = "/content/drive/MyDrive/Segmentation/MRI SEG/MRI/Train Images"
TRAIN_MASK_DIR = "/content/drive/MyDrive/Segmentation/MRI SEG/MRI/Train Masks"

train_transform = A.Compose([
    A.Resize(height=256, width=256),
    A.Rotate(limit=30, p=1.0, border_mode=cv2.BORDER_CONSTANT),
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.1),
    A.Normalize((0), (1)),  # mean 0 / std 1: only rescales by albumentations' max_pixel_value
    ToTensorV2(),
])

val_transform = A.Compose([
    A.Resize(height=256, width=256),
    A.Normalize((0), (1)),
    ToTensorV2(),
])

# Both datasets read the same folders; they differ only in their transform.
# The disjoint train/validation subsets are selected by the samplers below.
trainset = Dataset(TRAIN_IMAGE_DIR, TRAIN_MASK_DIR, transform=train_transform)
valset = Dataset(TRAIN_IMAGE_DIR, TRAIN_MASK_DIR, transform=val_transform)

n_train = len(trainset)
indices = np.random.default_rng(SPLIT_SEED).permutation(n_train).tolist()
split = int(np.floor(VAL_FRACTION * n_train))
train_idx, val_idx = indices[split:], indices[:split]
tsampler, vsampler = SubsetRandomSampler(train_idx), SubsetRandomSampler(val_idx)
train_loader = DataLoader(trainset, sampler=tsampler, batch_size=4)
valid_loader = DataLoader(valset, sampler=vsampler, batch_size=4)

MODEL_PATH = "/content/drive/MyDrive/Segmentation/MRI SEG/Model/multi_unet_2.pth"

model = UNET()
# model.load_state_dict(torch.load(MODEL_PATH))  # uncomment to resume training
model = model.cuda()
optimiser = optim.Adam(model.parameters(), lr=0.0001)
scaler = torch.cuda.amp.GradScaler()  # loss scaling for mixed-precision training
criterion = nn.NLLLoss()  # UNET.forward already applies log_softmax
min_acc = 0  # best validation pixel accuracy (%) so far; drives checkpointing
epochs = 30


def _prepare_batch(images, labels):
    """Move a batch to the GPU with targets in the form NLLLoss expects.

    Targets are cast to long class indices.  A singleton channel axis
    (N, 1, H, W) is dropped, because NLLLoss and the pixel comparison need
    (N, H, W) targets.  Whether the mask transform can produce that axis is
    not visible here, so treat this as a guard.
    """
    labels = labels.long()
    if labels.dim() == 4 and labels.size(1) == 1:
        labels = labels.squeeze(1)
    return images.cuda(), labels.cuda()


def evaluate(model, loader, criterion):
    """Return ``(mean batch loss, pixel accuracy in percent)`` over ``loader``.

    The model is put in eval mode (BatchNorm uses, and does not update, its
    running statistics) and autograd is disabled, so no graph is kept.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    if len(loader) == 0:
        raise ValueError("validation loader is empty - increase the validation split")
    model.eval()
    total_loss = 0.0
    num_correct = 0
    num_pixels = 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = _prepare_batch(images, labels)
            with torch.cuda.amp.autocast():
                log_probs = model(images)
                total_loss += criterion(log_probs, labels).item()
            # exp() is monotonic, so the argmax of the log-probabilities is the
            # predicted class.
            preds = log_probs.argmax(dim=1)
            num_correct += (preds == labels).sum().item()
            num_pixels += labels.numel()
    return total_loss / len(loader), 100.0 * num_correct / num_pixels


def run_validation(epoch):
    """Evaluate on ``valid_loader``, record the metrics, print and return the accuracy."""
    v_loss, accuracy = evaluate(model, valid_loader, criterion)
    validation_losses.append(v_loss)
    accuracies.append(accuracy)
    print("\nEPOCH:", epoch, "\tAccuracy:", accuracy, "\tValidation Loss:", v_loss)
    return accuracy


# Baseline metrics before any training step.
run_validation(0)

for e in range(1, epochs + 1):
    model.train()
    loop = tqdm(train_loader)
    for images, labels in loop:
        images, labels = _prepare_batch(images, labels)
        optimiser.zero_grad()
        with torch.cuda.amp.autocast():
            output = model(images)
            loss = criterion(output, labels)
        scaler.scale(loss).backward()
        scaler.step(optimiser)
        scaler.update()
        loss_value = loss.item()
        loop.set_postfix(loss=loss_value)
        training_losses.append(loss_value)

    accuracy = run_validation(e)

    # Keep only the checkpoint with the best validation pixel accuracy.
    if accuracy > min_acc:
        print("Saving model")
        min_acc = accuracy
        torch.save(model.state_dict(), MODEL_PATH)

# Save the training-loss, validation-loss and accuracy curves.
# max(..., default=0) keeps the y-limit computation from raising ValueError
# when a history is empty (e.g. epochs = 0 leaves training_losses empty).
plotVals(
    losses=training_losses,
    ylim=(0, max(training_losses, default=0) + 0.5),
    f_name="Training Loss",
    x_label="Training Step",
    y_label="Training Loss",
)
plotVals(
    losses=validation_losses,
    ylim=(0, max(validation_losses, default=0) + 0.5),
    f_name="Validation Loss",
    x_label="Epoch",
    y_label="Validation Loss",
)
plotVals(
    losses=accuracies,
    ylim=(0, 100),
    f_name="Accuracy",
    x_label="Epoch",
    y_label="Pixel Classification Accuracy",
)

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/

EPOCH: 0 	Accuracy: tensor(14.2893, device='cuda:0') 	Validation Loss: 1.6346693960103122


100%|██████████| 407/407 [01:04<00:00,  6.27it/s, loss=0.349]



EPOCH: 0 	Accuracy: tensor(96.1199, device='cuda:0') 	Validation Loss: 0.323328667066314
Saving model


100%|██████████| 407/407 [01:04<00:00,  6.28it/s, loss=0.159]



EPOCH: 0 	Accuracy: tensor(96.9333, device='cuda:0') 	Validation Loss: 0.172585242851214
Saving model


100%|██████████| 407/407 [01:04<00:00,  6.31it/s, loss=0.185]



EPOCH: 0 	Accuracy: tensor(96.6444, device='cuda:0') 	Validation Loss: 0.1309696656059135


100%|██████████| 407/407 [01:04<00:00,  6.32it/s, loss=0.069]



EPOCH: 0 	Accuracy: tensor(97.1013, device='cuda:0') 	Validation Loss: 0.09922263110903176
Saving model


100%|██████████| 407/407 [01:04<00:00,  6.28it/s, loss=0.0934]



EPOCH: 0 	Accuracy: tensor(97.0814, device='cuda:0') 	Validation Loss: 0.10682703588496555


100%|██████████| 407/407 [01:04<00:00,  6.31it/s, loss=0.0499]



EPOCH: 0 	Accuracy: tensor(97.2323, device='cuda:0') 	Validation Loss: 0.08892118795351549
Saving model


100%|██████████| 407/407 [01:06<00:00,  6.12it/s, loss=0.104]



EPOCH: 0 	Accuracy: tensor(97.2516, device='cuda:0') 	Validation Loss: 0.08048686402087862
Saving model


100%|██████████| 407/407 [01:04<00:00,  6.26it/s, loss=0.0433]



EPOCH: 0 	Accuracy: tensor(97.3199, device='cuda:0') 	Validation Loss: 0.07701504518362609
Saving model


100%|██████████| 407/407 [01:04<00:00,  6.29it/s, loss=0.0522]



EPOCH: 0 	Accuracy: tensor(97.2903, device='cuda:0') 	Validation Loss: 0.07972988537089391


100%|██████████| 407/407 [01:04<00:00,  6.32it/s, loss=0.0609]



EPOCH: 0 	Accuracy: tensor(97.2614, device='cuda:0') 	Validation Loss: 0.0759578448804942


100%|██████████| 407/407 [01:04<00:00,  6.29it/s, loss=0.0554]



EPOCH: 0 	Accuracy: tensor(97.3764, device='cuda:0') 	Validation Loss: 0.07274805796755986
Saving model


100%|██████████| 407/407 [01:05<00:00,  6.23it/s, loss=0.0747]



EPOCH: 0 	Accuracy: tensor(97.0840, device='cuda:0') 	Validation Loss: 0.0836539552970366


100%|██████████| 407/407 [01:04<00:00,  6.26it/s, loss=0.159]



EPOCH: 0 	Accuracy: tensor(97.3748, device='cuda:0') 	Validation Loss: 0.07336770218204368


100%|██████████| 407/407 [01:05<00:00,  6.25it/s, loss=0.0364]



EPOCH: 0 	Accuracy: tensor(97.3863, device='cuda:0') 	Validation Loss: 0.07253034345128319
Saving model


100%|██████████| 407/407 [01:05<00:00,  6.25it/s, loss=0.0655]



EPOCH: 0 	Accuracy: tensor(97.3778, device='cuda:0') 	Validation Loss: 0.07144215821542522


100%|██████████| 407/407 [01:05<00:00,  6.23it/s, loss=0.0775]



EPOCH: 0 	Accuracy: tensor(97.3364, device='cuda:0') 	Validation Loss: 0.07041824879971417


100%|██████████| 407/407 [01:05<00:00,  6.21it/s, loss=0.0954]



EPOCH: 0 	Accuracy: tensor(97.3385, device='cuda:0') 	Validation Loss: 0.09880645979534496


100%|██████████| 407/407 [01:05<00:00,  6.24it/s, loss=0.0393]



EPOCH: 0 	Accuracy: tensor(97.3695, device='cuda:0') 	Validation Loss: 0.0727871023118496


100%|██████████| 407/407 [01:04<00:00,  6.28it/s, loss=0.0345]



EPOCH: 0 	Accuracy: tensor(97.3560, device='cuda:0') 	Validation Loss: 0.06921309656040235


100%|██████████| 407/407 [01:04<00:00,  6.31it/s, loss=0.0356]



EPOCH: 0 	Accuracy: tensor(97.2319, device='cuda:0') 	Validation Loss: 0.09473963200368664


100%|██████████| 407/407 [01:04<00:00,  6.33it/s, loss=0.0678]



EPOCH: 0 	Accuracy: tensor(97.3564, device='cuda:0') 	Validation Loss: 0.07378191924230619


100%|██████████| 407/407 [01:04<00:00,  6.34it/s, loss=0.0349]



EPOCH: 0 	Accuracy: tensor(97.4257, device='cuda:0') 	Validation Loss: 0.07126809555021199
Saving model


100%|██████████| 407/407 [01:04<00:00,  6.28it/s, loss=0.055]



EPOCH: 0 	Accuracy: tensor(97.3361, device='cuda:0') 	Validation Loss: 0.07347118380394849


100%|██████████| 407/407 [01:04<00:00,  6.33it/s, loss=0.192]



EPOCH: 0 	Accuracy: tensor(97.1598, device='cuda:0') 	Validation Loss: 0.07473786649378864


100%|██████████| 407/407 [01:04<00:00,  6.33it/s, loss=0.037]



EPOCH: 0 	Accuracy: tensor(97.3668, device='cuda:0') 	Validation Loss: 0.07448158819567073


100%|██████████| 407/407 [01:04<00:00,  6.32it/s, loss=0.0341]



EPOCH: 0 	Accuracy: tensor(97.2845, device='cuda:0') 	Validation Loss: 0.07329791394824331


100%|██████████| 407/407 [01:04<00:00,  6.35it/s, loss=0.0615]



EPOCH: 0 	Accuracy: tensor(97.2390, device='cuda:0') 	Validation Loss: 0.074742602692409


100%|██████████| 407/407 [01:04<00:00,  6.34it/s, loss=0.08]



EPOCH: 0 	Accuracy: tensor(97.3774, device='cuda:0') 	Validation Loss: 0.07088587365367195


100%|██████████| 407/407 [01:04<00:00,  6.36it/s, loss=0.0708]



EPOCH: 0 	Accuracy: tensor(97.4358, device='cuda:0') 	Validation Loss: 0.07047555781900883
Saving model


100%|██████████| 407/407 [01:04<00:00,  6.34it/s, loss=0.0302]



EPOCH: 0 	Accuracy: tensor(97.2492, device='cuda:0') 	Validation Loss: 0.07511622086167336


In [None]:
# Debug check: shapes of the `output` and `labels` tensors left over from the
# last batch processed by the training cell above.
print(*(tensor.shape for tensor in (output, labels)))

torch.Size([4, 256, 256]) torch.Size([4, 1, 256, 256])
