In [None]:
# Mount Google Drive so the dataset, logs and checkpoints under MyDrive are reachable.
from google.colab import drive

DRIVE_MOUNTPOINT = '/content/drive'
drive.mount(DRIVE_MOUNTPOINT)

In [None]:
pip install tensorboardX

In [None]:
import os
import cv2
import sys
import tqdm
import torch
import datetime
import torchvision

import numpy as np
import torch.nn as nn
import skimage.io as io
import matplotlib.pyplot as plt
import torch.nn.functional as F
import torchvision.transforms as transforms

from PIL import Image
from skimage import img_as_float
from tensorboardX import SummaryWriter
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from sklearn.metrics import precision_recall_fscore_support

In [None]:
# Dataset root on the mounted Drive (contains the img/ and seg/ folders).
root = '/content/drive/MyDrive/C_scapes'

In [None]:
# Run on the GPU when PyTorch can see one, otherwise on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')


In [None]:
# --- Experiment configuration -------------------------------------------------
# Network to build: 'segnet' or 'fcn' (note: 'fcn' needs an FCN8s class, which this
# notebook does not define).
model_choice = 'segnet'
# True -> weight CrossEntropyLoss by inverse class frequency (see find_weights).
weighted_choice = False
# True -> train; False -> load a saved checkpoint and evaluate/visualize it.
train = False
# Images and masks are resized to img_size x img_size.
img_size = 256

In [None]:
# Seed every RNG the notebook draws from: NumPy's global RNG and PyTorch's global RNG
# (used for layer initialisation and DataLoader shuffling).
np.random.seed(42)
torch.manual_seed(42)


In [None]:
class _DecoderBlock(nn.Module):
    def __init__(self, in_channels, out_channels, num_conv_layers):
        super(_DecoderBlock, self).__init__()
        middle_channels = in_channels // 2
        layers = [
            nn.ConvTranspose2d(in_channels, in_channels, kernel_size=2, stride=2),
            nn.Conv2d(in_channels, middle_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(middle_channels),
            nn.ReLU(inplace=True)
        ]
        layers += [
                      nn.Conv2d(middle_channels, middle_channels, kernel_size=3, padding=1),
                      nn.BatchNorm2d(middle_channels),
                      nn.ReLU(inplace=True),
                  ] * (num_conv_layers - 2)
        layers += [
            nn.Conv2d(middle_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]
        self.decode = nn.Sequential(*layers)

    def forward(self, x):
        return self.decode(x)


class SegNet(nn.Module):
    """Encoder-decoder segmentation network on a VGG19-BN backbone.

    Despite the name, ``forward`` concatenates each encoder stage's output with the
    upsampled decoder features (U-Net-style skip connections) instead of reusing
    max-pooling indices.

    Args:
        num_classes: number of output channels (one logit map per class).
        pretrained: load ImageNet weights into the VGG19-BN encoder.
    """

    def __init__(self, num_classes, pretrained=True):
        super(SegNet, self).__init__()
        # Honour the `pretrained` argument instead of always downloading ImageNet weights.
        vgg = torchvision.models.vgg19_bn(pretrained=pretrained)
        features = list(vgg.features.children())
        # Split the VGG feature stack into five stages; enc1..enc4 outputs are reused as skips.
        self.enc1 = nn.Sequential(*features[0:7])
        self.enc2 = nn.Sequential(*features[7:14])
        self.enc3 = nn.Sequential(*features[14:27])
        self.enc4 = nn.Sequential(*features[27:40])
        self.enc5 = nn.Sequential(*features[40:])

        dec5_layers = [nn.ConvTranspose2d(512, 512, kernel_size=2, stride=2)]
        for _ in range(4):
            # Create new modules on every pass. `[conv, bn, relu] * 4` would repeat the same
            # objects and tie all four convs (and their BatchNorm statistics) together.
            dec5_layers += [
                nn.Conv2d(512, 512, kernel_size=3, padding=1),
                nn.BatchNorm2d(512),
                nn.ReLU(inplace=True),
            ]
        # Same number and order of entries as before, so state_dict keys are unchanged.
        self.dec5 = nn.Sequential(*dec5_layers)
        # Decoder input channels = skip channels + upsampled channels (see torch.cat in forward).
        self.dec4 = _DecoderBlock(1024, 256, 4)
        self.dec3 = _DecoderBlock(512, 128, 4)
        self.dec2 = _DecoderBlock(256, 64, 2)
        # NOTE(review): _DecoderBlock ends with BatchNorm + ReLU, so the class logits from dec1
        # are non-negative. Left unchanged because saved checkpoints were trained this way.
        self.dec1 = _DecoderBlock(128, num_classes, 2)

    def forward(self, x):
        """Return per-class logits of shape (N, num_classes, H', W').

        NOTE(review): presumably requires H and W divisible by 32 so the encoder and decoder
        feature maps have matching sizes at every torch.cat — confirm.
        """
        enc1 = self.enc1(x)
        enc2 = self.enc2(enc1)
        enc3 = self.enc3(enc2)
        enc4 = self.enc4(enc3)
        enc5 = self.enc5(enc4)

        dec5 = self.dec5(enc5)
        dec4 = self.dec4(torch.cat([enc4, dec5], 1))
        dec3 = self.dec3(torch.cat([enc3, dec4], 1))
        dec2 = self.dec2(torch.cat([enc2, dec3], 1))
        dec1 = self.dec1(torch.cat([enc1, dec2], 1))
        return dec1

In [None]:
# Build the network selected in the config cell (39 = number of output classes).
if model_choice == 'segnet':
    model = SegNet(39).to(device)
elif model_choice == 'fcn':
    # NOTE: FCN8s is not defined in this notebook; define or import it before choosing 'fcn'.
    model = FCN8s(39).to(device)
else:
    raise ValueError(f"Unknown model_choice {model_choice!r}; expected 'fcn' or 'segnet'")

In [None]:
# Mini-batch size for the (currently disabled) training DataLoader.
batch = 16


In [None]:
# Standard ImageNet channel mean / std used by Normalize.
_imagenet_mean = [0.485, 0.456, 0.406]
_imagenet_std = [0.229, 0.224, 0.225]


def _resize_and_normalize():
    """Build a pipeline: resize to img_size x img_size, convert to a tensor, ImageNet-normalize."""
    steps = [
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),
        transforms.Normalize(_imagenet_mean, _imagenet_std),
    ]
    return transforms.Compose(steps)


# Train and test pipelines are currently identical (no augmentation). They are separate
# objects so that either one can be changed independently.
train_transforms = _resize_and_normalize()
test_transforms = _resize_and_normalize()

In [None]:
class SegmentationDataset(Dataset):
    """Paired image / label-mask dataset read from ``<root>/img/<split>`` and ``<root>/seg/<split>``.

    ``split`` is ``subset_train`` when ``train`` is True, otherwise ``subset_val``. Images
    and masks are paired by position after sorting both folder listings by filename. Each
    mask must therefore sort to the same position as its image (e.g. identical or parallel
    names).

    Args:
        root: dataset root folder.
        size: ``(width, height)`` passed to ``cv2.resize`` for the mask. It should match the
            size produced by ``data_transforms`` for the image.
        train: use the training split instead of the validation split.
        data_transforms: callable applied to the RGB PIL image (normally returns a tensor).
            When None, the PIL image is returned unchanged.

    Raises:
        ValueError: if the image and mask folders contain different numbers of files.
    """

    def __init__(self, root, size, train=True, data_transforms=None):
        self.root = root
        split = 'subset_train' if train else 'subset_val'
        self.img_path = os.path.join(root, 'img', split)
        self.seg_path = os.path.join(root, 'seg', split)

        # os.listdir returns entries in arbitrary order. Sort both listings so that the
        # i-th image is paired with the i-th mask.
        self.img_list = sorted(os.listdir(self.img_path))
        self.seg_list = sorted(os.listdir(self.seg_path))
        if len(self.img_list) != len(self.seg_list):
            raise ValueError(
                f'{len(self.img_list)} images in {self.img_path} but '
                f'{len(self.seg_list)} masks in {self.seg_path}; cannot pair them'
            )
        self.size = size
        self.data_transforms = data_transforms

    def __len__(self):
        return len(self.img_list)

    def __getitem__(self, index):
        # convert('RGB') guarantees 3 channels (grayscale/palette/RGBA inputs otherwise break
        # the 3-channel Normalize). The `with` block closes the file handle.
        with Image.open(os.path.join(self.img_path, self.img_list[index])) as pil_img:
            img = pil_img.convert('RGB')
        if self.data_transforms is not None:
            img = self.data_transforms(img)

        seg_file = os.path.join(self.seg_path, self.seg_list[index])
        seg = cv2.imread(seg_file)
        if seg is None:  # cv2.imread reports missing/unreadable files by returning None
            raise FileNotFoundError(f'Could not read segmentation mask: {seg_file}')
        # interpolation must be passed by keyword. The third positional parameter of
        # cv2.resize is `dst`, so a positional INTER_NEAREST would be ignored. The default
        # bilinear interpolation would then blend neighbouring class ids into invalid labels.
        seg = cv2.resize(seg, self.size, interpolation=cv2.INTER_NEAREST)
        # Class ids are read from the first channel of the 3-channel image cv2 loads.
        seg = torch.from_numpy(seg[:, :, 0]).long()

        return img, seg

In [None]:
# Evaluation data: validation split only. The training split is disabled; re-enable these
# two lines to train (or to compute class weights with find_weights):
# train_dataset = SegmentationDataset('/content/drive/MyDrive/Cityscapes/final_train/', (img_size,img_size), data_transforms=train_transforms)
# train_dataloader = DataLoader(train_dataset, num_workers=2, shuffle=True, batch_size=batch)
test_dataset = SegmentationDataset(
    '/content/drive/MyDrive/C_scapes/',
    (img_size, img_size),
    train=False,
    data_transforms=test_transforms,
)
test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=True, num_workers=2)

In [None]:
# Sanity check: shapes of one (image, mask) batch and the class ids present in the mask.
# train_dataloader is commented out above, so inspect the validation loader instead.
sample_imgs, sample_segs = next(iter(test_dataloader))
print(sample_imgs.shape, sample_segs.shape)
print(np.unique(sample_segs.numpy()))

In [None]:
# Checkpoint to evaluate when not training.
# (FCN checkpoint, kept for reference: './all_models/fcn/models/fcn_upsample_best.pth')
if not train:
    path_segnet_unweighted = os.path.join(
        '/content/drive/MyDrive/C_scapes/segnet/models', 'segnet_upsample_best.pth'
    )

In [None]:
def checkAccuracy(pred, truth, batch_size):
    """Pixel accuracy: fraction of positions where ``pred`` equals ``truth``.

    Args:
        pred: predicted label tensor (any device).
        truth: ground-truth label tensor with the same shape as ``pred``.
        batch_size: kept for backward compatibility. The pixel count is taken from
            ``truth`` itself, so the result is correct for any image size and for a
            short final batch.

    Returns:
        float in [0, 1]; 0.0 for empty inputs.
    """
    pred = pred.cpu().numpy()
    truth = truth.cpu().numpy()
    if truth.size == 0:
        return 0.0
    return np.count_nonzero(pred == truth) / truth.size

def checkiou(pred, truth):
    """Mean intersection-over-union over the classes that occur in ``pred`` or ``truth``.

    Args:
        pred: integer class-id tensor (e.g. argmax of the logits).
        truth: integer class-id tensor with the same shape, on the same device.

    Returns:
        0-dim float tensor. NaN if both inputs are empty.
    """
    # Bitwise & / | on class ids is only an intersection/union for 0/1 masks. For
    # multi-class label maps it mixes unrelated bit patterns, so score each class separately.
    classes = torch.unique(torch.cat([pred.reshape(-1), truth.reshape(-1)]))
    if classes.numel() == 0:
        return torch.tensor(float('nan'), device=pred.device)
    ious = []
    for cls in classes:
        pred_c = pred == cls
        truth_c = truth == cls
        intersection = (pred_c & truth_c).sum().float()
        union = (pred_c | truth_c).sum().float()  # > 0: cls occurs in pred or truth
        ious.append(intersection / union)
    return torch.stack(ious).mean()

def count_parameters(model):
    """Return the total number of scalar parameters in ``model`` that have requires_grad set."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

def find_weights(dataset, num_classes):
    """Inverse-frequency class weights computed from the label masks of ``dataset``.

    Every class starts with a count of 1, so classes absent from the data still get a
    finite weight. The label mask of each item is read from ``item[1]`` (a tensor).

    Returns:
        numpy array of shape (num_classes, 1): 1 / (class pixel share).
    """
    pixel_counts = np.ones((num_classes, 1))
    for sample in dataset:
        labels, freq = np.unique(sample[1].numpy(), return_counts=True)
        for label, n_pixels in zip(labels, freq):
            pixel_counts[label] += n_pixels
    shares = pixel_counts / pixel_counts.sum()
    return (1.0 / shares).reshape(shares.shape[0], 1)

def checkclasswisedice(y_true, predicted, k=1):
    """Dice coefficient for class ``k``.

    Args:
        y_true: numpy array of ground-truth class ids.
        predicted: numpy array of predicted class ids with the same shape.
        k: class id to score.

    Returns:
        ``(dice, present)``. ``present`` is 1 if ``k`` occurs in ``y_true``, else 0.
        ``dice`` is 0.0 when ``k`` is absent from ``y_true``.
    """
    true_mask = y_true == k
    if not np.any(true_mask):
        # Previously fell through and returned None, which breaks tuple unpacking.
        return 0.0, 0
    pred_mask = predicted == k
    true_positives = np.count_nonzero(true_mask & pred_mask)
    # Dice = 2|P & G| / (|P| + |G|). The denominator counts *all* pixels predicted as k,
    # including false positives, not just the true positives.
    dice_score = 2.0 * true_positives / (np.count_nonzero(pred_mask) + np.count_nonzero(true_mask))
    return dice_score, 1

In [None]:
# Output folders for this run: <ROOT_DIR>/<model_choice>/logs/ and .../models/.
ROOT_DIR = '/content/drive/MyDrive/C_scapes/'
now = model_choice  # run name (one sub-folder per model type)
LOG_DIR = ROOT_DIR + now + '/logs/'
MODEL_DIR = ROOT_DIR + now + '/models/'

# Create any folder that does not exist yet, parents first.
for _run_dir in (ROOT_DIR, ROOT_DIR + now, LOG_DIR, MODEL_DIR):
    if not os.path.exists(_run_dir):
        os.makedirs(_run_dir)

# TensorBoard event files are written to LOG_DIR.
summary_writer = SummaryWriter(LOG_DIR)

In [None]:
# Loss: plain or inverse-class-frequency-weighted cross entropy over 39 classes.
if weighted_choice:
    # NOTE: requires `train_dataset`, whose cell is currently commented out.
    # find_weights returns a (num_classes, 1) column, but CrossEntropyLoss expects a 1-D
    # weight tensor of length num_classes, so flatten it.
    weights = torch.from_numpy(find_weights(train_dataset, 39)).float().view(-1).to(device)
    criterion = nn.CrossEntropyLoss(weight=weights).to(device)
else:
    criterion = nn.CrossEntropyLoss().to(device)

max_epoch = 50
optimizer = torch.optim.Adam(params=model.parameters(), lr=1.0e-3, betas=(0.9, 0.999), weight_decay=1.0e-4)

In [None]:
def eval_model(model, val_set):
    """Evaluate ``model`` on ``val_set`` one sample at a time.

    Args:
        model: segmentation network returning (1, C, H, W) logits.
        val_set: dataset of ``(image_tensor, label_tensor)`` pairs.

    Returns:
        ``(mean pixel accuracy, mean checkiou score)`` as Python floats, averaged over
        samples. Both are NaN for an empty dataset.
    """
    total_count = len(val_set)
    print('Total number of samples in the evaluation set is ', total_count)
    if total_count == 0:
        return float('nan'), float('nan')

    # Run on the device the model's weights live on (falls back to the global `device`).
    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else device

    val_loader = DataLoader(val_set, num_workers=2, shuffle=False, batch_size=1)

    total_acc, total_iou = 0.0, 0.0
    was_training = model.training
    model.eval()
    try:
        for imgs, labels in val_loader:
            imgs = imgs.float().to(run_device)
            labels = labels.long().to(run_device)

            with torch.no_grad():
                out = model(imgs)
            out_labels = torch.argmax(out, dim=1)

            prediction = out_labels.cpu().numpy()
            truth = labels.cpu().numpy()
            # Divide by the actual pixel count rather than assuming img_size x img_size.
            total_acc += np.count_nonzero(prediction == truth) / truth.size

            # float() moves the scalar off the GPU so the running sum is a plain number.
            total_iou += float(checkiou(out_labels, labels))
    finally:
        # Leave the model in whatever train/eval mode the caller had it in.
        model.train(was_training)

    return total_acc / total_count, total_iou / total_count

In [None]:
def visualize(model, val_set, index=0):
    """Run ``model`` on one sample of ``val_set`` and return numpy arrays for plotting.

    Debug variant that also prints the raw network output.

    Args:
        model: segmentation network returning (1, C, H, W) logits.
        val_set: indexable dataset of ``(image_tensor, label_tensor)`` pairs.
        index: which sample to show (default 0). Pass ``np.random.randint(len(val_set))``
            for a random one.

    Returns:
        ``(img, gt, pred)``: the (still normalized) input as (H, W, C), the ground-truth
        class ids, and the argmax prediction.
    """
    # eval(): BatchNorm must use its running statistics, not the statistics of this single
    # image. The caller's train/eval mode is restored afterwards.
    was_training = model.training
    model.eval()
    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else device

    sample = val_set[index]
    imgs = sample[0].unsqueeze(0).float().to(run_device)
    labels = sample[1].unsqueeze(0).long().to(run_device)

    try:
        with torch.no_grad():
            out = model(imgs)
    finally:
        model.train(was_training)

    print(out)

    out_labels = torch.argmax(out, dim=1)

    img = imgs[0].permute(1, 2, 0).cpu().numpy()
    gt = labels[0].cpu().numpy()
    pred = out_labels[0].cpu().numpy()

    return img, gt, pred

In [None]:
def visualize(model, val_set, index=0):
    """Run ``model`` on one sample of ``val_set`` and return numpy arrays for plotting.

    Debug variant that prints the input, label and output tensor shapes.

    Args:
        model: segmentation network returning (1, C, H, W) logits.
        val_set: indexable dataset of ``(image_tensor, label_tensor)`` pairs.
        index: which sample to show (default 0). Pass ``np.random.randint(len(val_set))``
            for a random one.

    Returns:
        ``(img, gt, pred)``: the (still normalized) input as (H, W, C), the ground-truth
        class ids, and the argmax prediction.
    """
    # eval(): BatchNorm must use its running statistics, not the statistics of this single
    # image. The caller's train/eval mode is restored afterwards.
    was_training = model.training
    model.eval()
    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else device

    sample = val_set[index]
    imgs = sample[0].unsqueeze(0).float().to(run_device)
    labels = sample[1].unsqueeze(0).long().to(run_device)

    print('imgs:', imgs.shape)
    print('labels:', labels.shape)

    try:
        with torch.no_grad():
            out = model(imgs)
            print('out:', out.shape)
    finally:
        model.train(was_training)

    out_labels = torch.argmax(out, dim=1)

    img = imgs[0].permute(1, 2, 0).cpu().numpy()
    gt = labels[0].cpu().numpy()
    pred = out_labels[0].cpu().numpy()

    return img, gt, pred


In [None]:
# Final definition of visualize (shadows the two earlier versions above).
def visualize(model, val_set, index=0):
    """Run ``model`` on one sample of ``val_set`` and return arrays for plotting.

    Args:
        model: segmentation network returning (1, C, H, W) logits with C >= 2
            (channel 1 is returned separately).
        val_set: indexable dataset of ``(image_tensor, label_tensor)`` pairs.
        index: which sample to show (default 0). Pass ``np.random.randint(len(val_set))``
            for a random one.

    Returns:
        ``(img, gt, pred, out_channel1)``:
        img: the (still normalized) input as an (H, W, C) numpy array.
        gt: ground-truth class ids, numpy.
        pred: argmax prediction, numpy.
        out_channel1: the raw channel-1 logits as a (1, H, W) tensor on the model's device.
    """
    # eval(): BatchNorm must use its running statistics, not the statistics of this single
    # image. The caller's train/eval mode is restored afterwards.
    was_training = model.training
    model.eval()
    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else device

    sample = val_set[index]
    imgs = sample[0].unsqueeze(0).float().to(run_device)
    labels = sample[1].unsqueeze(0).long().to(run_device)

    try:
        with torch.no_grad():
            out = model(imgs)
    finally:
        model.train(was_training)

    out_channel1 = out[:, 1, :, :]
    out_labels = torch.argmax(out, dim=1)

    img = imgs[0].permute(1, 2, 0).cpu().numpy()
    gt = labels[0].cpu().numpy()
    pred = out_labels[0].cpu().numpy()

    return img, gt, pred, out_channel1


In [None]:
# Best SegNet checkpoint (unweighted loss) saved under <ROOT>/segnet/models/.
path_segnet_unweighted = os.path.join(
    '/content/drive/MyDrive/C_scapes', 'segnet', 'models', 'segnet_upsample_best.pth'
)

In [None]:
# Rebuild the network and load the trained weights. pretrained=False skips downloading
# ImageNet VGG weights that the checkpoint overwrites anyway (the flag only has an effect
# if SegNet forwards it to vgg19_bn).
model = SegNet(39, pretrained=False).to(device)
model.eval()  # inference: BatchNorm uses running statistics
# map_location lets a checkpoint saved on a GPU load on a CPU-only runtime.
model.load_state_dict(torch.load(path_segnet_unweighted, map_location=device))

In [None]:
# visualize() (final definition) returns four values; the channel-1 logits are unused here.
img, gt, pred, _ = visualize(model, test_dataset)

In [None]:
# Keep all four outputs, including the raw channel-1 logits used by the heat-map cell.
vis_outputs = visualize(model, test_dataset)
img, gt, pred, out_channel1 = vis_outputs


In [None]:
# Binary display mask for class 1: 255 where the prediction is class 1, 0 elsewhere.
# np.where builds a new array, so `pred` is left untouched. Aliasing `image = pred` and
# editing it in place would silently overwrite the prediction too. It also replaces an
# O(H*W) Python loop.
image = np.where(pred == 1, 255, 0).astype(pred.dtype)

In [None]:
np.shape(img)  # (height, width, channels) of the input returned by visualize

In [None]:
# Side-by-side comparison: input, ground truth, and the class-1 mask from the thresholding cell.
# The input was ImageNet-normalized by test_transforms; undo that so imshow gets values in [0, 1].
img_display = np.clip(img * np.array([0.229, 0.224, 0.225]) + np.array([0.485, 0.456, 0.406]), 0, 1)

fig, axes = plt.subplots(1, 3, figsize=(20, 7))
panels = ((img_display, 'Input image'), (gt, 'Ground truth'), (image, 'Output of U-SegNet'))
for ax, (panel, title) in zip(axes, panels):
    ax.imshow(panel)
    ax.set_title(title)
    ax.set_axis_off()

# Save before plt.show(): some backends release the figure once it has been shown.
fig_path = '/content/drive/MyDrive/Cityscapes/final_train/segnet.png'
os.makedirs(os.path.dirname(fig_path), exist_ok=True)  # savefig does not create folders
fig.savefig(fig_path, bbox_inches='tight', pad_inches=0)
plt.show()

In [None]:
# Raw channel-1 logits from visualize() shown as a heat map. The numpy copy gets its own name
# so re-running this cell does not call .detach() on a numpy array.
# (plt is already imported in the imports cell.)
out_channel1_map = out_channel1[0].detach().cpu().numpy()

fig, ax = plt.subplots()
heat = ax.imshow(out_channel1_map)
ax.set_title('Channel-1 logits')
fig.colorbar(heat, ax=ax)
plt.show()


In [None]:
# Side-by-side comparison: input, ground truth, and the class-1 mask from the thresholding cell.
# The input was ImageNet-normalized by test_transforms; undo that so imshow gets values in [0, 1].
img_display = np.clip(img * np.array([0.229, 0.224, 0.225]) + np.array([0.485, 0.456, 0.406]), 0, 1)

fig, axes = plt.subplots(1, 3, figsize=(20, 7))
panels = ((img_display, 'Input image'), (gt, 'Ground truth'), (image, 'Output of U-SegNet'))
for ax, (panel, title) in zip(axes, panels):
    ax.imshow(panel)
    ax.set_title(title)
    ax.set_axis_off()

# Save before plt.show(): some backends release the figure once it has been shown.
fig_path = '/content/drive/MyDrive/Cityscapes/final_train/segnet.png'
os.makedirs(os.path.dirname(fig_path), exist_ok=True)  # savefig does not create folders
fig.savefig(fig_path, bbox_inches='tight', pad_inches=0)
plt.show()

In [None]:
# Side-by-side comparison: input, ground truth, and the class-1 mask from the thresholding cell.
# The input was ImageNet-normalized by test_transforms; undo that so imshow gets values in [0, 1].
img_display = np.clip(img * np.array([0.229, 0.224, 0.225]) + np.array([0.485, 0.456, 0.406]), 0, 1)

fig, axes = plt.subplots(1, 3, figsize=(20, 7))
panels = ((img_display, 'Input image'), (gt, 'Ground truth'), (image, 'Output of U-SegNet'))
for ax, (panel, title) in zip(axes, panels):
    ax.imshow(panel)
    ax.set_title(title)
    ax.set_axis_off()

# Save before plt.show(): some backends release the figure once it has been shown.
fig_path = '/content/drive/MyDrive/Cityscapes/final_train/segnet.png'
os.makedirs(os.path.dirname(fig_path), exist_ok=True)  # savefig does not create folders
fig.savefig(fig_path, bbox_inches='tight', pad_inches=0)
plt.show()

In [None]:
# Side-by-side comparison: input, ground truth, and the class-1 mask from the thresholding cell.
# The input was ImageNet-normalized by test_transforms; undo that so imshow gets values in [0, 1].
img_display = np.clip(img * np.array([0.229, 0.224, 0.225]) + np.array([0.485, 0.456, 0.406]), 0, 1)

fig, axes = plt.subplots(1, 3, figsize=(20, 7))
panels = ((img_display, 'Input image'), (gt, 'Ground truth'), (image, 'Output of U-SegNet'))
for ax, (panel, title) in zip(axes, panels):
    ax.imshow(panel)
    ax.set_title(title)
    ax.set_axis_off()

# Save before plt.show(): some backends release the figure once it has been shown.
fig_path = '/content/drive/MyDrive/Cityscapes/final_train/segnet.png'
os.makedirs(os.path.dirname(fig_path), exist_ok=True)  # savefig does not create folders
fig.savefig(fig_path, bbox_inches='tight', pad_inches=0)
plt.show()

In [None]:
# Input image next to its ground-truth mask (no prediction panel).
# The input was ImageNet-normalized by test_transforms; undo that so imshow gets values in [0, 1].
img_display = np.clip(img * np.array([0.229, 0.224, 0.225]) + np.array([0.485, 0.456, 0.406]), 0, 1)

fig, axes = plt.subplots(1, 2, figsize=(14, 7))
panels = ((img_display, 'Input image'), (gt, 'Ground truth'))
for ax, (panel, title) in zip(axes, panels):
    ax.imshow(panel)
    ax.set_title(title)
    ax.set_axis_off()

# Save before plt.show(): some backends release the figure once it has been shown.
fig_path = '/content/drive/MyDrive/Cityscapes/final_train/segnet.png'
os.makedirs(os.path.dirname(fig_path), exist_ok=True)  # savefig does not create folders
fig.savefig(fig_path, bbox_inches='tight', pad_inches=0)
plt.show()

In [None]:
# Input image next to its ground-truth mask (no prediction panel).
# The input was ImageNet-normalized by test_transforms; undo that so imshow gets values in [0, 1].
img_display = np.clip(img * np.array([0.229, 0.224, 0.225]) + np.array([0.485, 0.456, 0.406]), 0, 1)

fig, axes = plt.subplots(1, 2, figsize=(14, 7))
panels = ((img_display, 'Input image'), (gt, 'Ground truth'))
for ax, (panel, title) in zip(axes, panels):
    ax.imshow(panel)
    ax.set_title(title)
    ax.set_axis_off()

# Save before plt.show(): some backends release the figure once it has been shown.
fig_path = '/content/drive/MyDrive/Cityscapes/final_train/segnet.png'
os.makedirs(os.path.dirname(fig_path), exist_ok=True)  # savefig does not create folders
fig.savefig(fig_path, bbox_inches='tight', pad_inches=0)
plt.show()

In [None]:
from PIL import Image

# One-off preprocessing: resize a validation image to 2048x1024 (width x height).
# WARNING: this overwrites the source file on Drive. The write is skipped when the image
# already has the target size, so re-running the cell is a no-op. New variable names avoid
# clobbering the `image` mask used by the plotting cells.
resize_path = '/content/drive/MyDrive/C_scapes/img/subset_val/img1.BMP'
target_size = (2048, 1024)

resized_image = None
with Image.open(resize_path) as src_img:
    if src_img.size != target_size:
        resized_image = src_img.resize(target_size)

if resized_image is not None:
    resized_image.save(resize_path)
