In [1]:
import torch
#import albumentations as A
#from albumentations.pytorch import ToTensorV2
from tqdm import tqdm
import torch.nn as nn
import torch.optim as optim
from model import UNET
from utils import (
    load_checkpoint,
    get_loaders,
    check_accuracy,
    save_predictions_as_imgs,
    mean_IOU_score,
    mean_dice_score
)
import os
from torch.utils.tensorboard import SummaryWriter
import numpy as np
import pandas as pd
import torchvision.transforms as T
from PIL import Image
import cv2
import matplotlib.pyplot as plt

In [2]:
# Hyperparameters and filesystem locations used by the cells below.
LEARNING_RATE = 1e-4

# Device selection: GPU 3 is the preferred device, but a hard-coded "cuda:3"
# crashes on machines with fewer than four GPUs, so fall back to the first GPU
# or to the CPU when the preferred one does not exist.
_PREFERRED_GPU_INDEX = 3
if not torch.cuda.is_available():
    DEVICE = "cpu"
elif torch.cuda.device_count() > _PREFERRED_GPU_INDEX:
    DEVICE = "cuda:3"
else:
    DEVICE = "cuda:0"

BATCH_SIZE = 1
NUM_EPOCHS = 150
NUM_WORKERS = 0
IMAGE_HEIGHT = 608  # 1280 originally
IMAGE_WIDTH = 960  # 1920 originally
PIN_MEMORY = True
LOAD_MODEL = False

# Dataset locations -- these have to be adapted to the local machine.
IMG_DIR = r"E:\Thesis-Rishav\A2D2_dataset\images"    #Have to set this
MASK_DIR = r"E:\Thesis-Rishav\A2D2_dataset\seg_label"
TRAIN_CSV_FILE= r"E:\Thesis-Rishav\A2D2_dataset\train.csv"
TEST_CSV_FILE= r"E:\Thesis-Rishav\A2D2_dataset\test.csv"
#WRITER = True  # Controlling the tensorboard

In [62]:

# Grey-level distance between consecutive class ids in the saved PNGs.
_CLASS_GRAY_STEP = 12


def _save_class_map(class_ids, image_path):
    """Save a tensor of integer class ids as a grey-coded RGB PNG.

    Each pixel is written as ``class_id * 12`` (clamped to 0..255). The tensor
    is converted to ``uint8`` before it reaches ``ToPILImage``: that transform
    multiplies floating-point tensors by 255 and casts them to 8-bit, so
    passing ``class_id * 12`` as a float overflowed and produced grey levels
    other than the intended ones.

    Args:
        class_ids: tensor of class indices laid out as (channels, H, W) for
            ``ToPILImage`` -- assumes a batch of size 1 so the batch axis acts
            as the single channel; TODO confirm if BATCH_SIZE changes.
        image_path: destination file; its directory is created if missing.
    """
    gray = (class_ids.long() * _CLASS_GRAY_STEP).clamp(0, 255).to(torch.uint8).cpu()
    directory = os.path.dirname(image_path)
    if directory:  # empty when the path has no directory part on this OS
        os.makedirs(directory, exist_ok=True)
    T.ToPILImage()(gray).convert('RGB').save(image_path)


def visualize_result(preds, batch_idx):
    """Save the predicted class map of one batch as ``Image_<batch_idx>.png``.

    Args:
        preds: raw model output with the class scores along ``dim=1``.
        batch_idx: index used in the output file name.
    """
    # argmax over the class axis; a preceding softmax is monotonic and does
    # not change which class wins, so it is omitted.
    class_ids = torch.argmax(preds, dim=1)
    image_name= r'E:\Thesis-Rishav\Baselines\Semantic_Segmentation\saved_images\Image_{}.png'.format(batch_idx)
    _save_class_map(class_ids, image_name)


def visualize_target(targets, batch_idx):
    """Save the ground-truth class map of one batch as ``Image_<batch_idx>.png``.

    Uses the same grey coding as ``visualize_result`` so both outputs can be
    compared side by side.

    Args:
        targets: tensor of class ids as yielded by the data loader.
        batch_idx: index used in the output file name.
    """
    image_name= r'E:\Thesis-Rishav\Baselines\Semantic_Segmentation\saved_targets\Image_{}.png'.format(batch_idx)
    _save_class_map(targets, image_name)

def inference_fn(test_loader,model, loss_fn, scaler):
    """Evaluate ``model`` on ``test_loader`` and average the per-batch metrics.

    For every batch the input image, the target mask and the predicted mask
    are written to disk as PNGs named after the batch index, and the values
    returned by ``mean_IOU_score`` and ``mean_dice_score`` are collected.

    Args:
        test_loader: iterable yielding ``(data, targets)`` batches. The image
            dumps assume a batch size of 1 -- TODO confirm if BATCH_SIZE changes.
        model: network to evaluate. Inputs are moved to ``DEVICE`` before the
            forward pass, so the model is expected to be on that device.
        loss_fn: unused; kept so existing callers keep working.
        scaler: unused; kept so existing callers keep working.

    Returns:
        ``(mean_IOU, mean_DICE)``: the sum of the per-batch metric values
        divided by the number of batches.

    Raises:
        ValueError: if ``test_loader`` yields no batches.
    """
    loop = tqdm(test_loader)
    IOU_score = []
    Dice_score = []
    to_pil = T.ToPILImage()  # loop-invariant, build once

    # Remember the current mode so it can be restored afterwards instead of
    # unconditionally forcing the model back into training mode.
    was_training = model.training
    model.eval()
    try:
        for batch_idx, (data, targets) in enumerate(loop):
            # Dump the network input next to the masks for visual inspection.
            image = to_pil(np.squeeze(data)).convert('RGB')
            image_name= r'E:\Thesis-Rishav\Baselines\Semantic_Segmentation\input_images\Image_{}.png'.format(batch_idx)
            image.save(image_name)
            data = data.to(device=DEVICE)

            visualize_target(targets, batch_idx)
            targets = targets.long().to(device=DEVICE)

            with torch.no_grad():
                predictions = model(data.float())

                iou_score = mean_IOU_score(predictions, targets)
                dice_score = mean_dice_score(predictions, targets)
                # Return value is not used here; the call is kept because the
                # helper's behaviour is not visible from this notebook.
                check_accuracy(predictions, targets, model)
                IOU_score.append(iou_score)
                Dice_score.append(dice_score)
                visualize_result(predictions, batch_idx)
    finally:
        # Restore the mode even if a batch fails half-way through.
        model.train(was_training)

    if not IOU_score:
        raise ValueError("test_loader yielded no batches; cannot average metrics")

    mean_IOU = sum(IOU_score) / len(IOU_score)
    mean_DICE = sum(Dice_score) / len(Dice_score)

    return  mean_IOU, mean_DICE



    

In [65]:
import albumentations as A
from albumentations.pytorch import ToTensorV2

# Test-time preprocessing: resize to the configured resolution and convert to
# a tensor. No augmentation is applied.
test_transform = A.Compose(
    [
        A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
        ToTensorV2(),
    ],
)

CHECKPOINT_PATH = r"E:\Thesis-Rishav\Baselines\Semantic_Segmentation\Trained_parameters\model_60.pth"

model = UNET(in_channels=3, out_channels=22).to(DEVICE)
# map_location remaps the stored tensors onto DEVICE; without it torch.load
# tries to restore them to the device they were saved from, which fails when
# that GPU index does not exist on this machine.
# NOTE(review): torch.load unpickles the file -- only load checkpoints from a
# trusted source.
model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location=DEVICE))
loss_fn = nn.CrossEntropyLoss()

test_loader = get_loaders(
    IMG_DIR,
    MASK_DIR,
    TEST_CSV_FILE,
    BATCH_SIZE,
    NUM_WORKERS,
    PIN_MEMORY,
    transform=test_transform,
)

# Only needed because inference_fn still takes a scaler argument.
scaler = torch.cuda.amp.GradScaler()

In [66]:
# Run the evaluation over the whole test loader; this also writes the input,
# target and prediction images to disk for every batch.
mean_IOU, mean_DICE = inference_fn(
    test_loader,
    model,
    loss_fn,
    scaler,
)


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2728/2728 [46:05<00:00,  1.01s/it]


In [67]:

# Show the averaged metric vectors returned by inference_fn: IoU first, Dice second.
for metric_vector in (mean_IOU, mean_DICE):
    print(metric_vector)

[0.8026299  0.4600761  0.38752896 0.3430002  0.67108628 0.6877569
 0.46200335 0.7585434  1.         0.92890762 0.58300243 0.38274481
 0.92965755 0.17415645 0.33960791 0.8118996  0.38615553 0.50660771
 0.51169588 0.93097461 0.64152301 0.25421066]
tensor([0.8606, 0.4982, 0.4395, 0.3873, 0.6777, 0.7270, 0.5770, 0.7823, 1.0000,
        0.9615, 0.5935, 0.5218, 0.9297, 0.2281, 0.4474, 0.8675, 0.4203, 0.6026,
        0.5259, 0.9506, 0.7242, 0.3618])


In [68]:
def _average(values):
    """Return the plain arithmetic mean of ``values`` (sum divided by length)."""
    return sum(values) / len(values)


# Collapse each metric vector into a single overall score.
test_mIOU = _average(mean_IOU)
test_DICE = _average(mean_DICE)

In [69]:
# Report the overall scores: mean IoU first, mean Dice second.
for overall_score in (test_mIOU, test_DICE):
    print(overall_score)

0.5888076756002929
tensor(0.6402)
