In [1]:
import os

import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from matplotlib import pyplot as plt
from pycocotools.coco import COCO
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

ModuleNotFoundError: No module named 'pycocotools'

In [None]:
# Choose the compute device once: the first CUDA GPU when one is visible, else the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
print(device)

In [None]:
def imshow(img):
    """Show a normalized CHW image tensor with matplotlib.

    Undoes the mean=0.5 / std=0.5 normalization (x * 0.5 + 0.5) so values map
    back to [0, 1], then transposes CHW -> HWC for ``plt.imshow``.

    Args:
        img: image tensor of shape [C, H, W]; may live on any device and may
             require grad.
    """
    # detach/cpu: .numpy() raises on CUDA tensors and on tensors that require grad
    img = img.detach().cpu() / 2 + 0.5
    # float rounding can push values just outside [0, 1]; matplotlib would clip
    # them anyway but emits a warning each time
    img = img.clamp(0, 1)
    plt.imshow(np.transpose(img.numpy(), (1, 2, 0)))
    plt.show()
    
def display_data_sample(img, target):
    """Display one (image, target) pair from the dataset.

    Shows the image via ``imshow``; then, if the sample has any boxes, shows the
    union of all instance masks as a single black/white image. Otherwise prints
    'no annotations'.

    Args:
        img    : normalized image tensor [C, H, W]
        target : dict with at least 'boxes' and 'masks' ([N, H, W]) tensors
    """
    imshow(img)
    print()
    if len(target['boxes']):
        masks = target['masks'].detach().cpu().numpy()
        # Union of the instance masks. Summing uint8 masks and casting back to
        # uint8 wraps around wherever boxes overlap (255 + 255 -> 254); a logical
        # OR cannot overflow and works whether foreground is stored as 1 or 255.
        mask = (masks > 0).any(axis=0).astype(np.uint8)
        plt.imshow(mask, cmap='gray', vmin=0, vmax=1)
        plt.show()
    else:
        print('no annotations')

In [None]:
class TILDetectionDataset(Dataset):
    """COCO-format detection dataset for TIL detection.

    Items are ``(img, target)`` pairs in the input format of torchvision's
    detection models (e.g. Mask R-CNN). Every annotation is treated as the same
    foreground class (label 1), and its mask is its filled bounding box.
    """

    def __init__(self, root, annFile):
        """
            root    : directory joined with each image's 'file_name'
            annFile : path to the COCO annotation JSON
        """
        self.root = root
        self.annFile = annFile
        self.coco = COCO(annFile)
        # COCO image ids are dictionary keys (commonly starting at 1 and not
        # necessarily contiguous), so a positional index must be mapped to an
        # id rather than used as one.
        self.ids = sorted(self.coco.imgs.keys())

    def __len__(self):
        """return the number of images"""
        return len(self.ids)

    def __getitem__(self, idx):
        """
            returns (img, target) for the idx-th image (0 <= idx < len(self))
            img    : float tensor [3, H, W], normalized to [-1, 1]
            target : dict with keys :
                            boxes  (FloatTensor[N, 4], absolute xmin, ymin, xmax, ymax)
                            labels (Int64Tensor[N], all 1)
                            masks  (UInt8Tensor[N, H, W], 0/1 filled boxes)
            Annotations with non-positive width or height are dropped.
        """
        img_id = self.ids[idx]
        filename = os.path.join(self.root, self.coco.imgs[img_id]['file_name'])
        # force 3 channels so grayscale/RGBA files still match the 3-channel normalization
        img = torchvision.io.read_image(filename, mode=torchvision.io.ImageReadMode.RGB)
        img = img.to(torch.float32) / 255
        _, H, W = img.shape
        # [0, 1] -> [-1, 1]
        img = torchvision.transforms.functional.normalize(img, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        anns = self.coco.loadAnns(self.coco.getAnnIds(imgIds=img_id))
        return img, self._build_target(anns, H, W)

    @staticmethod
    def _build_target(anns, H, W):
        """Convert a list of COCO annotations into a detection target dict."""
        boxes = []
        for ann in anns:
            x, y, w, h = ann['bbox']  # COCO bbox: [x, y, width, height]
            # torchvision detection models reject boxes without positive extent
            if w <= 0 or h <= 0:
                continue
            boxes.append([x, y, x + w, y + h])
        # reshape keeps shape [0, 4] (not [0]) when there are no annotations
        boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)
        labels = torch.ones((len(boxes),), dtype=torch.int64)
        # Mask R-CNN expects binary (0/1) masks: the mask values are used directly
        # as binary-cross-entropy targets, so 255 would be wrong.
        masks = torch.zeros((len(boxes), H, W), dtype=torch.uint8)
        for i, (xmin, ymin, xmax, ymax) in enumerate(boxes.tolist()):
            # half-open pixel range covering the (possibly fractional) box, clipped to the image
            x0, y0 = max(int(np.floor(xmin)), 0), max(int(np.floor(ymin)), 0)
            x1, y1 = min(int(np.ceil(xmax)), W), min(int(np.ceil(ymax)), H)
            masks[i, y0:y1, x0:x1] = 1
        return {
            'boxes': boxes,
            'labels': labels,
            'masks': masks}

In [None]:
# Location of the ROI-level tissue-cell images and their COCO annotation file.
# Keep the trailing slash: the dataset class combines root and 'file_name'.
DATA_ROOT = '/data/wsirois/roi-level-annotations/tissue-cells/'
ANN_FILE = DATA_ROOT + 'tiger-coco.json'
TRAIN_FRACTION = 0.8
SEED = 42

# create dataset
dataset = TILDetectionDataset(DATA_ROOT, ANN_FILE)

# create train and test split from the data
train_size = int(len(dataset) * TRAIN_FRACTION)
test_size = len(dataset) - train_size
# seeded generator -> identical split on every Restart & Run All
train_dataset, test_dataset = torch.utils.data.random_split(
    dataset, [train_size, test_size], generator=torch.Generator().manual_seed(SEED)
)

In [None]:
# Sanity check: render the first training sample (image, then its masks).
(img, target) = train_dataset[0]
display_data_sample(img=img, target=target)

In [None]:
# create the model
# num_classes counts the background: index 0 is background, index 1 is the single
# foreground class the dataset assigns to every annotation (labels are all ones).
NUM_CLASSES = 2
# NOTE(review): `pretrained` / `pretrained_backbone` are deprecated since torchvision 0.13
# in favour of `weights=None, weights_backbone=...`; kept for compatibility with older
# torchvision — switch once the installed version is confirmed.
model = torchvision.models.detection.maskrcnn_resnet50_fpn(
    pretrained=False, num_classes=NUM_CLASSES, pretrained_backbone=True
)
# Assign rather than leaving `model.to(device)` as the cell's last expression,
# which would print the entire network repr into the output.
model = model.to(device)

In [None]:
# Train one image per step (effective batch size 1); unannotated images are skipped.
# NOTE(review): Adam's weight_decay is coupled L2 regularisation; AdamW may have been intended — confirm.
NUM_EPOCHS = 10

optimizer = optim.Adam(model.parameters(), lr=3e-4, weight_decay=0.01)

train_losses = []

for epoch in range(NUM_EPOCHS):  # loop over the dataset multiple times
    # torchvision detection models return the loss dict only in training mode
    model.train()

    running_train_loss = 0.0
    running_test_loss = 0.0  # placeholder: no evaluation pass in this cell yet
    n_steps = 0

    with tqdm(train_dataset, unit="batch", desc=f"Epoch {epoch}") as tepoch:
        for img, target in tepoch:
            if len(target['boxes']) == 0:
                continue

            # detection models take lists of images and lists of target dicts
            images = [img.to(device)]
            targets = [{key: value.to(device) for key, value in target.items()}]

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            loss_dict = model(images, targets)
            # Sum the loss *tensors*. Calling .item() on a term turns it into a
            # Python float, detaching it from the graph so it contributes no gradient.
            loss = sum(loss_dict.values())
            loss.backward()
            optimizer.step()

            # track statistics (.item() copies the scalar to the host and drops the graph)
            loss_value = loss.item()
            running_train_loss += loss_value
            n_steps += 1
            tepoch.set_postfix(loss=loss_value)

    # mean loss over the steps actually taken (annotated images only)
    train_losses.append(running_train_loss / max(n_steps, 1))