# Imports

### Installations

In [None]:
!pip install segmentation-models-pytorch
!pip install torchmetrics

### Imports

In [None]:
from pycocotools.coco import COCO
import numpy as np
import os
from IPython.display import clear_output
import segmentation_models_pytorch as smp
import torchvision
import torch
import torchmetrics as tm
import PIL
import random
import cv2
torchvision.disable_beta_transforms_warning()
import torchvision.transforms.v2 as transforms

### For visualizing the outputs ###
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
%matplotlib inline


# Loading Data

### Functions for loading COCO dataset

In [None]:
def filterDataset(folder, classes=None, mode='train'):
    """Load a COCO annotation file and list the images to use.

    Args:
        folder: Dataset root. Annotations are read from
            ``{folder}/annotations/coco-{mode}.json``.
        classes: Optional iterable of category names. When given, an image is
            kept if it contains at least one of these categories. When
            ``None`` every image in the annotation file is returned.
        mode: Split name substituted into the annotation file name.

    Returns:
        Tuple ``(unique_images, dataset_size, coco)``: the shuffled list of
        image records, its length, and the ``COCO`` handle that was opened.
    """
    # initialize COCO api for instance annotations
    annFile = '{}/annotations/coco-{}.json'.format(folder, mode)
    coco = COCO(annFile)

    if classes is not None:
        images = []
        for className in classes:
            # Wrap the name in a list: pycocotools treats a bare string as a
            # sequence and tests `cat['name'] in catNms`, which for a string is
            # a substring match ('Leaf' would also match 'Leaf_Diseased') and
            # would then intersect the image sets of both categories.
            catIds = coco.getCatIds(catNms=[className])
            imgIds = coco.getImgIds(catIds=catIds)
            images += coco.loadImgs(imgIds)
    else:
        images = coco.loadImgs(coco.getImgIds())

    # An image containing several requested classes was collected once per
    # class; keep the first occurrence of every image id.
    seen_ids = set()
    unique_images = []
    for image in images:
        if image['id'] not in seen_ids:
            seen_ids.add(image['id'])
            unique_images.append(image)

    random.shuffle(unique_images)
    dataset_size = len(unique_images)

    return unique_images, dataset_size, coco

def getClassName(classID, cats):
    """Look up a category name by id.

    Args:
        classID: Category id to search for.
        cats: Sequence of category records, each with ``'id'`` and ``'name'``.

    Returns:
        The ``'name'`` of the first record whose ``'id'`` equals ``classID``,
        or ``None`` when no record matches.
    """
    matches = (cat['name'] for cat in cats if cat['id'] == classID)
    return next(matches, None)

def getImage(imageObj, img_folder, input_image_size):
    """Read one image, scale it by 1/255 and resize it, always with 3 channels.

    Args:
        imageObj: Image record; only ``imageObj['file_name']`` is read.
        img_folder: Directory that contains the image file.
        input_image_size: Target size passed to ``torchvision.transforms.Resize``.

    Returns:
        A channel-first tensor with exactly three channels. Single-channel
        (grayscale) inputs are replicated across the three channels, a
        gray+alpha input uses its first channel only, and any channels beyond
        the third (e.g. alpha) are dropped.
    """
    # Read and normalize an image
    image_path = img_folder + '/' + imageObj['file_name']
    train_img = torchvision.io.read_image(image_path) / 255.0
    # Resize
    train_img = torchvision.transforms.Resize(size=input_image_size)(train_img)

    channels = train_img.shape[0]
    if channels == 3:
        return train_img
    if channels in (1, 2):
        # Grayscale (optionally with alpha): repeat the luminance channel along
        # the channel axis so the layout stays (3, H, W).
        return train_img[:1].repeat(3, 1, 1)
    # More than three channels (RGBA): keep the colour channels only.
    return train_img[:3]
    
def getNormalMask(imageObj, classes, coco, catIds, input_image_size):
    """Build a multi-class label mask for one image.

    Each annotation is rasterised with ``coco.annToMask`` and filled with
    ``classes.index(category_name) + 1``; 0 is left for unannotated pixels.
    Where annotations overlap, the largest label value wins.

    Args:
        imageObj: Image record; ``imageObj['id']`` selects its annotations.
        classes: List of category names; the position of a name defines its label.
        coco: COCO handle providing ``getAnnIds``, ``loadAnns``, ``loadCats``
            and ``annToMask``.
        catIds: Category ids used to filter the annotations.
        input_image_size: ``(height, width)`` of the returned mask.

    Returns:
        ``numpy`` array of shape ``(1, height, width)``.

    Raises:
        ValueError: If an annotation's category name is not in ``classes``.
    """
    annIds = coco.getAnnIds(imageObj['id'], catIds=catIds, iscrowd=None)
    anns = coco.loadAnns(annIds)
    id_to_name = {cat['id']: cat['name'] for cat in coco.loadCats(catIds)}

    height, width = input_image_size[0], input_image_size[1]
    train_mask = np.zeros((height, width))
    for ann in anns:
        className = id_to_name.get(ann['category_id'])
        pixel_value = classes.index(className) + 1
        # cv2.resize expects dsize as (width, height). Nearest-neighbour keeps
        # the values valid class labels; interpolating would create
        # in-between labels along object borders.
        new_mask = cv2.resize(coco.annToMask(ann) * pixel_value, (width, height),
                              interpolation=cv2.INTER_NEAREST)
        train_mask = np.maximum(new_mask, train_mask)

    # Add a leading channel axis: (height, width) -> (1, height, width)
    train_mask = train_mask.reshape(1, height, width)
    return train_mask
    
def getBinaryMask(imageObj, coco, catIds, input_image_size):
    """Build a foreground/background mask for one image.

    Every annotation matching ``catIds`` is rasterised with ``coco.annToMask``,
    resized, and merged with an element-wise maximum, so a pixel is 1 when any
    annotation covers it and 0 otherwise.

    Args:
        imageObj: Image record; ``imageObj['id']`` selects its annotations.
        coco: COCO handle providing ``getAnnIds``, ``loadAnns`` and ``annToMask``.
        catIds: Category ids used to filter the annotations.
        input_image_size: ``(height, width)`` of the returned mask.

    Returns:
        ``numpy`` array of shape ``(height, width, 1)`` (channel-last).
    """
    annIds = coco.getAnnIds(imageObj['id'], catIds=catIds, iscrowd=None)
    anns = coco.loadAnns(annIds)

    height, width = input_image_size[0], input_image_size[1]
    train_mask = np.zeros((height, width))
    for ann in anns:
        # cv2.resize expects dsize as (width, height), not (height, width).
        new_mask = cv2.resize(coco.annToMask(ann), (width, height),
                              interpolation=cv2.INTER_NEAREST)

        # Threshold because resizing may cause extraneous values
        new_mask = (new_mask >= 0.5).astype(new_mask.dtype)

        train_mask = np.maximum(new_mask, train_mask)

    # Add a trailing channel axis: (height, width) -> (height, width, 1)
    train_mask = train_mask.reshape(height, width, 1)
    return train_mask


def dataGeneratorCoco(images, classes, coco, folder, 
                      input_image_size=(224,224), batch_size=4, mode='train', mask_type='binary'):
    """Endlessly yield ``(image_batch, mask_batch)`` arrays for a COCO split.

    Args:
        images: Sequence of image records to draw from. It is copied, so the
            reshuffling between passes does not reorder the caller's list.
        classes: Category names; forwarded to ``coco.getCatIds`` and to
            ``getNormalMask``.
        coco: COCO handle for the split.
        folder: Dataset root; images are read from ``{folder}/images/{mode}``.
        input_image_size: ``(height, width)`` of every yielded sample.
        batch_size: Number of samples per yielded batch.
        mode: Split name used to build the image directory.
        mask_type: ``'binary'`` (``getBinaryMask``) or ``'normal'``
            (``getNormalMask``).

    Yields:
        ``(img, mask)`` float arrays of shape ``(batch_size, 3, H, W)`` and
        ``(batch_size, 1, H, W)``. Only full batches are produced; once the
        remaining records cannot fill another batch the records are reshuffled
        and iteration restarts from the beginning.

    Raises:
        ValueError: On first iteration, if ``mask_type`` is unknown or there
            are fewer records than ``batch_size``.
    """
    if mask_type not in ('binary', 'normal'):
        raise ValueError("mask_type must be 'binary' or 'normal', got {!r}".format(mask_type))

    img_folder = '{}/images/{}'.format(folder, mode)
    images = list(images)
    dataset_size = len(images)
    if dataset_size < batch_size:
        raise ValueError('need at least batch_size={} images, got {}'.format(batch_size, dataset_size))
    catIds = coco.getCatIds(catNms=classes)
    height, width = input_image_size[0], input_image_size[1]

    c = 0
    while True:
        img = np.zeros((batch_size, 3, height, width)).astype('float')
        mask = np.zeros((batch_size, 1, height, width)).astype('float')

        for slot, imageObj in enumerate(images[c:c + batch_size]):
            ### Retrieve Image ###
            train_img = getImage(imageObj, img_folder, input_image_size)

            ### Create Mask ###
            if mask_type == 'binary':
                train_mask = getBinaryMask(imageObj, coco, catIds, input_image_size)
            else:
                train_mask = getNormalMask(imageObj, classes, coco, catIds, input_image_size)

            # Add to respective batch sized arrays
            img[slot] = train_img
            # The binary mask is (H, W, 1) and the normal mask (1, H, W); both
            # hold H*W values, so reshape to the channel-first slot layout.
            mask[slot] = np.asarray(train_mask).reshape(1, height, width)

        c += batch_size
        # Restart only when the NEXT batch would run past the end; using `>=`
        # here would silently drop the final full batch of every pass.
        if c + batch_size > dataset_size:
            c = 0
            random.shuffle(images)
        yield img, mask

In [None]:
# Dataset location, label set and the split to inspect.
# Other label sets that were tried: ['laptop', 'tv', 'cell phone'], or None for all images.
folder = '/content/drive/MyDrive/COCOdatasettomato'
classes = ['Leaf', 'Leaf_Diseased', 'Background']
mode = 'train'

# Open the annotations for that split and resolve the category ids once.
images, dataset_size, coco = filterDataset(folder, classes, mode)
catIds = coco.getCatIds(catNms=classes)

# Print the image records, how many there are, and the COCO handle.
for summary in (images, dataset_size, coco):
    print(summary)

### Load into arrays

In [None]:
input_image_size = (512,512)


def _load_split(split, image_size, num_classes=3):
    """Load every image of one split together with its one-hot mask.

    Args:
        split: Split name (``'train'`` or ``'val'``); selects both the
            annotation file and the ``{folder}/images/{split}`` directory.
        image_size: ``(height, width)`` every image and mask is resized to.
        num_classes: Number of one-hot channels in the mask. Channel ``k`` is
            255 where the label mask equals ``k`` and 0 elsewhere, so pixels
            whose label is ``>= num_classes`` end up 0 in every channel.

    Returns:
        ``(images, masks)`` as ``uint8`` arrays of shape ``(N, 3, H, W)`` and
        ``(N, num_classes, H, W)``.
    """
    records, _, split_coco = filterDataset(folder, classes, mode=split)
    # Resolve category ids from this split's own annotation file instead of
    # relying on a `catIds` global left over from an earlier cell.
    split_cat_ids = split_coco.getCatIds(catNms=classes)

    height, width = image_size[0], image_size[1]
    split_images = np.empty((len(records), 3, height, width), dtype=np.uint8)
    split_masks = np.empty((len(records), num_classes, height, width), dtype=np.uint8)

    image_dir = folder + '/images/' + split
    for i, img_json in enumerate(records):
        img = torchvision.io.read_image(image_dir + '/' + img_json['file_name'])
        if img.shape[0] in (1, 2):
            # Grayscale (optionally with alpha): replicate to three channels.
            img = img[:1].repeat(3, 1, 1)
        elif img.shape[0] > 3:
            # Drop the alpha channel.
            img = img[:3]
        img = torchvision.transforms.functional.resize(img, image_size)

        label = getNormalMask(img_json, classes, split_coco, split_cat_ids, image_size)
        label = label.astype(int)[0]
        for k in range(num_classes):
            split_masks[i, k] = (label == k) * 255
        split_images[i] = img.numpy()

    return split_images, split_masks


train_images, train_masks = _load_split('train', input_image_size)
val_images, val_masks = _load_split('val', input_image_size)


### Visualize Mask

In [None]:
# The stored arrays are channel-first (3, H, W); matplotlib expects (H, W, 3).
# Moving the channel axis last with (1, 2, 0) keeps rows and columns in place,
# whereas (2, 1, 0) would also swap height and width and show the picture
# transposed.
plt.imshow(np.transpose(train_images[0], (1,2,0)))
plt.show()
# The mask is shown as RGB: one colour channel per one-hot class plane.
plt.imshow(np.transpose(train_masks[0], (1,2,0)))
plt.show()

### Dataset Class

In [None]:
class COCOdataset(torch.utils.data.Dataset):
    """Dataset over pre-loaded image / one-hot-mask arrays.

    Args:
        image_tensor: Indexable collection of channel-first ``uint8`` images.
        mask_tensor: Indexable collection of channel-first one-hot masks,
            aligned with ``image_tensor``.
        transform: Optional callable ``transform(image, mask) -> (image, mask)``
            that receives two PIL images and returns two channel-first
            tensors. When ``None`` the PIL images are only converted to float
            tensors scaled to ``[0, 1]``.
    """

    def __init__(self, image_tensor, mask_tensor, transform=None):
        self.image_tensor = image_tensor
        self.mask_tensor = mask_tensor
        self.transform = transform

    def __len__(self):
        return len(self.image_tensor)

    @staticmethod
    def _to_float_tensor(pil_image):
        """Convert a PIL image to a channel-first float tensor in ``[0, 1]``."""
        array = np.array(pil_image)
        return torch.from_numpy(array).permute(2, 0, 1).float() / 255

    def __getitem__(self, idx):
        """Return ``(image, mask)``; ``mask`` holds class indices, shape ``(1, ., .)``."""
        image = self.image_tensor[idx]
        mask = self.mask_tensor[idx]

        # NOTE(review): (2,1,0) reverses all axes, so besides moving channels
        # last it also swaps the two spatial axes. Image and mask are permuted
        # identically and therefore stay aligned with each other.
        image = PIL.Image.fromarray(np.transpose(image, (2,1,0)))
        mask = PIL.Image.fromarray(np.transpose(mask, (2,1,0)))

        if self.transform is not None:
            image, mask = self.transform(image, mask)
        else:
            # `transform` defaults to None; calling it unconditionally would
            # raise TypeError, so fall back to a plain tensor conversion.
            image = self._to_float_tensor(image)
            mask = self._to_float_tensor(mask)

        # Collapse the one-hot channels into a single plane of class indices.
        mask = torch.argmax(mask, dim=0).unsqueeze(0)
        return image, mask

### Load into Dataset Class

In [None]:
batch_size = 8

# Training augmentation: random flips, sharpening and rotation, then tensor conversion.
transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),  # Randomly flip the image horizontally
    transforms.RandomVerticalFlip(),
    transforms.RandomAdjustSharpness(0.5),
    transforms.RandomRotation(360),
    transforms.ToTensor(),           # Convert the image to a tensor
])

# Validation must be repeatable: no random augmentation, only tensor conversion.
# Otherwise the validation loss that drives the LR scheduler changes from
# epoch to epoch for reasons unrelated to the model.
val_transform = transforms.Compose([
    transforms.ToTensor(),
])

train_dataset = COCOdataset(train_images, train_masks, transform=transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_dataset = COCOdataset(val_images, val_masks, transform=val_transform)
# Fixed order, so "first validation batch" always means the same images.
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

### Visualize

In [None]:
# Pull one training batch and show every image followed by its mask,
# then print the distinct class indices present in that mask.
images, masks = next(iter(train_loader))
for image, mask in zip(images, masks):
  for sample in (image, mask):
    plt.imshow(np.transpose(sample, (2,1,0)))
    plt.show()
  print(np.unique(mask))

# print(len(train_loader))
# for i, (images, masks) in enumerate(iter(train_loader)):
#   plt.imshow(np.transpose(images[0], (2,1,0)))
#   plt.show()
#   plt.imshow(np.transpose(masks[0], (2,1,0)))
#   plt.show()

# Model

### Helper Functions and Assign Device

In [None]:
# Pick the compute device: the GPU when CUDA is available, otherwise the CPU.
# An explicit torch.device (instead of None) keeps every later
# `.to(device)` / `map_location=device` call well-defined on CPU-only machines.
if torch.cuda.is_available():
    device = torch.device('cuda')
    print(torch.cuda.get_device_name(device))
else:
    device = torch.device('cpu')
    print('GPU is not available')

# def get_metrics(preds, masks):
#   tp, fp, fn, tn = smp.metrics.get_stats(preds.unsqueeze(1), masks, mode='multiclass', num_classes=3)

#   precision = sum(tp) / (sum(tp) + sum(fp))
#   recall = sum(tp) / (sum(tp) + sum(fn))
#   accuracy = (sum(tp) + sum(tn)) / (sum(tp)+sum(fp)+sum(tn)+sum(fn))
#   f1 = 2*precision*recall / (precision + recall)
  
#   return (precision, recall, accuracy, f1)

def get_metrics(tp, fp, fn, tn):
  """Derive precision, recall, accuracy and F1 from confusion counts.

  Args:
    tp, fp, fn, tn: True-positive, false-positive, false-negative and
      true-negative counts. The arithmetic is element-wise, so tensors give
      one value per element.

  Returns:
    Tuple ``(precision, recall, accuracy, f1)``. No zero-division guard is
    applied, so with tensor inputs a 0/0 ratio comes out as NaN.
  """
  predicted_positive = (tp) + (fp)
  actual_positive = (tp) + (fn)
  total = (tp) + (fp) + (tn) + (fn)

  precision = (tp) / predicted_positive
  recall = (tp) / actual_positive
  accuracy = ((tp) + (tn)) / total
  f1 = 2*precision*recall / (precision + recall)

  return (precision, recall, accuracy, f1)

def get_stats(preds, masks):
  """Compute multiclass confusion statistics for a batch of predictions.

  ``preds`` gets a new axis at position 1 before it is handed, together with
  ``masks``, to ``smp.metrics.get_stats`` in ``'multiclass'`` mode with three
  classes.

  Returns:
    Tuple ``(tp, fp, fn, tn)`` exactly as produced by ``smp.metrics.get_stats``.
  """
  preds_with_channel = preds.unsqueeze(1)
  true_pos, false_pos, false_neg, true_neg = smp.metrics.get_stats(
      preds_with_channel, masks, mode='multiclass', num_classes=3)

  return (true_pos, false_pos, false_neg, true_neg)

def print_metrics(p,r,a,f1):
  """Print one row of precision / recall / accuracy / F1 per class.

  Args:
    p, r, a, f1: Equal-length sequences (e.g. 1-D tensors) with one value per
      class.

  The column header is printed once. A class whose metrics contain NaN (an
  undefined 0/0 ratio) is reported as ``n/a`` and the remaining classes are
  still printed.
  """
  print('   P       R       A       F1')
  for i, row in enumerate(zip(p,r,a,f1)):
    _p, _r, _a, _f1 = (float(value) for value in row)
    # NaN is the only float that is not equal to itself.
    if any(value != value for value in (_p, _r, _a, _f1)):
      print(f'{i}: n/a (metric undefined)')
      continue
    print(f'{i}: {_p:.4f}   {_r:.4f}   {_a:.4f}   {_f1:.4f}')

def display(display_list):
  """Show the given arrays side by side in one row.

  Args:
    display_list: Sequence of channel-last images (``(H, W)``, ``(H, W, 1)``
      or ``(H, W, 3)``), as numpy arrays or tensors. The first three panels
      are titled 'Input Image', 'True Mask' and 'Predicted Mask'; any
      further panels are left untitled.

  Each panel is min-max rescaled before plotting, so value ranges do not have
  to match between panels.
  """
  def _to_displayable(item):
    # Accept tensors (possibly on the GPU / attached to a graph) as well as arrays.
    if hasattr(item, 'detach'):
      item = item.detach().cpu().numpy()
    array = np.asarray(item, dtype=float)
    if array.ndim == 3 and array.shape[-1] == 1:
      array = array[..., 0]
    array = array - array.min()
    peak = array.max()
    if peak != 0:
      array = array / peak
    return array

  plt.figure(figsize=(15, 15))

  title = ['Input Image', 'True Mask', 'Predicted Mask']

  for i in range(len(display_list)):
    plt.subplot(1, len(display_list), i+1)
    if i < len(title):
      plt.title(title[i])
    # Plain numpy/matplotlib conversion: the notebook never imports TensorFlow,
    # so the former tf.keras.utils.array_to_img call raised NameError.
    plt.imshow(_to_displayable(display_list[i]))
    plt.axis('off')
  plt.show()

def show_predictions(images, true_masks, pred_masks):
    """
    Plot, for every sample in a batch, the input image next to its true and predicted mask.

    Args:
    images: tensor of shape (batch_size, channels, height, width) with the batch of images
    true_masks: tensor of shape (batch_size, 1, height, width) with the ground-truth masks
    pred_masks: tensor indexed per sample; each entry is drawn directly with imshow,
        so it is expected to be a 2-D (height, width) label map
    """

    # Move everything to host memory as numpy arrays
    images_np, true_np, pred_np = (
        batch.cpu().numpy() for batch in (images, true_masks, pred_masks)
    )

    for index in range(images_np.shape[0]):
        # One figure per sample, three panels in a row
        fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(10, 5))

        # (title, data, extra imshow arguments); channel-first arrays are moved to channel-last
        panels = (
            ('Original Image', np.transpose(images_np[index], (1, 2, 0)), {}),
            ('True Mask', np.transpose(true_np[index], (1, 2, 0)), {'cmap': 'gray'}),
            ('Predicted Mask', pred_np[index], {'cmap': 'gray'}),
        )
        for axis, (panel_title, panel_data, imshow_kwargs) in zip(axes, panels):
            axis.imshow(panel_data, **imshow_kwargs)
            axis.set_title(panel_title)
            axis.axis('off')

        # Render the figure
        plt.show()


### Training

In [None]:
train_size = len(train_dataset)
val_size = len(val_dataset)
steps_per_epoch = train_size // batch_size
validation_steps = val_size // batch_size

# ---- Define model ----
import torch.nn as nn
import torch.optim as optim
import torch
from segmentation_models_pytorch.encoders import get_preprocessing_fn

NUM_CLASSES = 3

# U-Net++ with an ImageNet-pretrained MobileNetV2 encoder.
# (resnet34 / resnet50 encoders were tried as alternatives.)
model = smp.UnetPlusPlus(encoder_name='mobilenet_v2', encoder_weights='imagenet', in_channels=3, classes=NUM_CLASSES)
model = model.to(device=device)

preprocess_input = get_preprocessing_fn('mobilenet_v2', pretrained='imagenet')

# ---- Hyper parameters ----
optimizer = optim.Adam(model.parameters(), lr=0.001)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=3)
# Multiclass Jaccard (IoU) loss; weighted cross-entropy was tried as an alternative.
loss_function = smp.losses.JaccardLoss(mode='multiclass').to(device)

# ---- Fixed sample batch used to visualise progress after every epoch ----
sample_images, sample_masks = next(iter(val_loader))
sample_images = sample_images.to(device=device)
sample_masks = sample_masks.to(device=device)

# ---- Histories for the plots ----
train_loss_list = []
val_loss_list = []
train_p_list = []
train_r_list = []
val_p_list = []
val_r_list = []


def _run_epoch(loader, training):
    """Run one pass over ``loader``.

    Args:
        loader: DataLoader yielding ``(images, masks)`` batches.
        training: When True the model is put in train mode and the optimizer
            is stepped on every batch; when False the model is put in eval
            mode and no gradients are recorded.

    Returns:
        ``(mean_loss, tp, fp, fn, tn)`` where ``mean_loss`` is the loss
        averaged over the samples seen and the four tensors hold per-class
        counts summed over the whole pass.
    """
    totals = [torch.zeros(NUM_CLASSES, dtype=torch.long) for _ in range(4)]
    loss_sum = 0.0
    samples_seen = 0

    model.train(training)
    with torch.set_grad_enabled(training):
        for images, masks in loader:
            images = images.to(device=device)
            masks = masks.to(device=device)

            if training:
                optimizer.zero_grad()
            outputs = model(images)
            loss = loss_function(outputs, masks)
            if training:
                loss.backward()
                optimizer.step()

            # Weight by batch size so a smaller final batch counts proportionally.
            loss_sum += loss.item() * images.size(0)
            samples_seen += images.size(0)

            labels = torch.softmax(outputs, dim=1).argmax(dim=1)
            # get_stats returns (tp, fp, fn, tn) per sample; sum over the batch
            # and accumulate on the CPU, wherever the stats were computed.
            for total, batch_stat in zip(totals, get_stats(labels, masks)):
                total += batch_stat.sum(dim=0).cpu()

    # Both train and validation losses are per-sample means, so the two
    # curves are directly comparable; max() avoids dividing by zero on an
    # empty loader.
    mean_loss = loss_sum / max(samples_seen, 1)
    return (mean_loss, *totals)


# ---- Training ----
num_epochs = 80
for epoch in range(num_epochs):
    print('Epoch: ' + str(epoch))

    train_loss, *train_stats = _run_epoch(train_loader, training=True)
    p, r, a, f1 = get_metrics(*train_stats)
    print_metrics(p, r, a, f1)

    val_loss, *val_stats = _run_epoch(val_loader, training=False)
    # update learning rate
    scheduler.step(val_loss)
    print(f"Epoch {epoch+1}/{num_epochs}, train_loss: {train_loss}, val_loss: {val_loss}")
    p, r, a, f1 = get_metrics(*val_stats)
    print_metrics(p, r, a, f1)

    train_loss_list.append(train_loss)
    val_loss_list.append(val_loss)

    # Visual check on the fixed sample batch; the model is still in eval mode
    # here and no autograd graph is needed.
    with torch.no_grad():
        outputs = model(sample_images)
        probs = torch.softmax(outputs, dim=1)
        _, labels = torch.max(probs, dim=1)
    show_predictions(sample_images, sample_masks, labels)

    # plt.show()


In [None]:
import matplotlib.pyplot as plt

# Train and validation loss per epoch on shared axes
fig, ax = plt.subplots()
for history, curve_label in ((train_loss_list, 'Train Loss'), (val_loss_list, 'Validation Loss')):
    ax.plot(history, label=curve_label)
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.set_title('Train and Validation Loss')
ax.legend()
plt.show()

In [None]:
# Predict one validation batch with the trained model and show the results.
model.eval()
images, masks = next(iter(val_loader))
images = images.to(device=device)
masks = masks.to(device=device)
# Inference only: skip autograd bookkeeping so no graph is kept in memory.
with torch.no_grad():
    outputs = model(images)
    # Apply softmax activation function to the output
    probs = torch.softmax(outputs, dim=1)
    # Get the predicted labels
    _, labels = torch.max(probs, dim=1)

show_predictions(images, masks, labels)

In [None]:
# Persist the whole model object (not just a state_dict) with torch.save.
# NOTE(review): the file name says "mobilenetv3", but the model built in the
# training cell uses the 'mobilenet_v2' encoder, and the evaluation cell loads
# a different file ('.../mobilenetv2.3') — confirm which checkpoint is intended.
torch.save(model, '/content/drive/MyDrive/saved_seg_models/mobilenetv3-80ep')

# Evaluate

In [None]:
# Load a fully pickled model for evaluation.
# SECURITY: torch.load with weights_only=False unpickles arbitrary Python
# objects — only load checkpoint files you created yourself or fully trust.
# weights_only=False is stated explicitly because newer PyTorch releases
# default to True, which rejects a pickled nn.Module; map_location lets a
# checkpoint saved on a GPU be opened on a CPU-only machine.
eval_model = torch.load('/content/drive/MyDrive/saved_seg_models/mobilenetv2.3',
                        map_location=device, weights_only=False)
eval_model.eval()
eval_model = eval_model.to(device)

In [None]:
# total_tp = torch.from_numpy(np.zeros_like([
#     [0,0,0],
#     [0,0,0],
#     [0,0,0],
#     [0,0,0],
#     [0,0,0],
#     [0,0,0],
#     [0,0,0],
#     [0,0,0]
# ]))
# total_fp = torch.from_numpy(np.zeros_like(total_tp))
# total_tn = torch.from_numpy(np.zeros_like(total_tp))
# total_fn = torch.from_numpy(np.zeros_like(total_tp))
def _pad_rows(stat, rows):
  """Return ``stat`` on the CPU, zero-padded along dim 0 to ``rows`` rows.

  A final batch smaller than ``batch_size`` yields fewer rows; padding keeps
  every batch the same shape so the running totals can be added element-wise.
  """
  stat = stat.cpu()
  missing = rows - stat.size(0)
  if missing <= 0:
    return stat
  padded = torch.zeros((stat.size(0) + missing, stat.size(1))).long()
  padded[:stat.size(0), :] = stat
  return padded


total_tp = 0
total_fp = 0
total_tn = 0
total_fn = 0

# NOTE(review): this section is titled "Evaluate" but iterates train_loader
# (shuffled, randomly augmented) — confirm whether val_loader was intended.
# Inference only, so autograd is switched off for the whole pass.
with torch.no_grad():
  for i, (images, masks) in enumerate(iter(train_loader)):
    images = images.to(device=device)
    masks = masks.to(device=device)
    outputs = eval_model(images)
    # Apply softmax activation function to the output
    probs = torch.softmax(outputs, dim=1)
    # Get the predicted labels
    _, labels = torch.max(probs, dim=1)

    tp, fp, fn, tn = get_stats(labels, masks)
    print(tp)
    tp, fp, fn, tn = (_pad_rows(stat, batch_size) for stat in (tp, fp, fn, tn))

    # Totals stay (batch_size, 3): row = position within the batch, column = class.
    total_tp += tp
    total_fp += fp
    total_fn += fn
    total_tn += tn
  # p,r,a,f1 = get_metrics(tp,fp,fn,tn)
  # print(p,r,a,f1)

  # show_predictions(images, masks, labels)
  
  # iou_score = smp.metrics.iou_score(tp, fp, fn, tn, reduction="micro")
  # f1_score = smp.metrics.f1_score(tp, fp, fn, tn, reduction="micro")
  # f2_score = smp.metrics.fbeta_score(tp, fp, fn, tn, beta=2, reduction="micro")
  # accuracy = smp.metrics.accuracy(tp, fp, fn, tn, reduction="macro")
  # recall = smp.metrics.recall(tp, fp, fn, tn, reduction="micro-imagewise")

In [None]:
# precision = total_tp / (total_tp + total_fp)
# recall = total_tp / (total_tp + total_fn)
# Column 2 of every accumulated statistic, i.e. the counts for class index 2.
class2_stats = [stat[:, 2] for stat in (total_tp, total_fp, total_fn, total_tn)]
precision = smp.metrics.precision(*class2_stats, reduction='macro')
recall = smp.metrics.recall(*class2_stats, reduction='macro')
precision, recall