<a href="https://colab.research.google.com/github/Waye/CSC420-CourseWork-fall2019/blob/master/Assignment3.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# 1.Image Segmentation

In [0]:
# helper functions
import os
import pickle

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
from skimage import io, transform
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms, utils

from nn import Net

use_gpu = torch.cuda.is_available()


class CatDataset(Dataset):
    """Dataset of cat photos paired with their segmentation masks.

    The pickle file must hold an indexable collection whose entries provide
    the photo filename at position 0 and the mask filename at position 1.
    Photos are read from ``<root_dir>/input/`` and masks from
    ``<root_dir>/mask/``.

    Args:
        root_dir: Directory containing the ``input/`` and ``mask/`` folders.
        pkl: Path to the pickle file listing the filename pairs.
        transform: Optional callable applied to every ``{'cat', 'mask'}``
            sample dict before it is returned.
    """

    def __init__(self, root_dir, pkl, transform=None):
        self.root_dir = root_dir
        self.transform = transform

        # NOTE: unpickling can execute arbitrary code; only load index files
        # that were produced locally.
        with open(pkl, 'rb') as f:
            self.data = pickle.load(f)

    def __len__(self):
        """Return the number of (photo, mask) pairs in the index."""
        return len(self.data)

    def __getitem__(self, index):
        """Load one sample.

        Returns:
            A dict with keys ``'cat'`` and ``'mask'`` (after ``transform``
            if one was given).

        Raises:
            ValueError: If the photo's last axis does not have length 3.
        """
        if torch.is_tensor(index):
            index = index.tolist()

        entry = self.data[index]
        cat_name = os.path.join(self.root_dir, 'input/', entry[0])
        seg_name = os.path.join(self.root_dir, 'mask/', entry[1])
        cat = io.imread(cat_name)
        segmentation = io.imread(seg_name)
        if cat.shape[-1] != 3:
            # Raise instead of calling exit(): terminating the interpreter
            # from inside a dataset hides the cause and cannot be handled
            # by the caller.
            raise ValueError(
                f'expected 3 channels in {cat_name}, got shape {cat.shape}'
            )

        sample = {'cat': cat, 'mask': segmentation}

        if self.transform:
            sample = self.transform(sample)

        return sample


class Rescale(object):
    """Resize the photo and the mask of a sample, then binarise the mask.

    Args:
        output_size (tuple or int): Target size. A tuple is used as-is as
            ``(height, width)``. An int is matched to the shorter image edge
            while the aspect ratio is kept.
    """

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        self.output_size = output_size

    def _target_shape(self, image):
        """Return the integer ``(rows, cols)`` that ``image`` is resized to."""
        rows, cols = image.shape[:2]
        size = self.output_size
        if not isinstance(size, int):
            out_rows, out_cols = size
        elif rows > cols:
            out_rows, out_cols = size * rows / cols, size
        else:
            out_rows, out_cols = size, size * cols / rows
        return int(out_rows), int(out_cols)

    def __call__(self, sample):
        """Return a new ``{'cat', 'mask'}`` dict holding the resized arrays."""
        photo = sample['cat']
        segmentation = sample['mask']

        # Each array is sized from its own shape, so photo and mask only end
        # up identical in size when a tuple size is used or they share an
        # aspect ratio.
        photo = transform.resize(photo, self._target_shape(photo))
        segmentation = transform.resize(
            segmentation, self._target_shape(segmentation))

        # Snap the resized mask back to a hard 0/1 labelling.
        segmentation[segmentation <= 0.5] = 0
        segmentation[segmentation > 0.5] = 1
        return {'cat': photo, 'mask': segmentation}


class ToTensor(object):
    """Convert the ndarrays of a sample into channel-first tensors."""

    def __call__(self, sample):
        """Return ``{'cat': C x H x W tensor, 'mask': 1 x H x W tensor}``.

        The photo must be ``H x W x C``; the mask must be ``H x W`` or
        ``H x W x 1``. The mask is returned as float64.
        """
        cat, mask = sample['cat'], sample['mask']

        # swap color axis because
        # numpy image: H x W x C
        # torch image: C X H X W
        cat = cat.transpose((2, 0, 1))
        # Derive the channel-last shape from the mask itself rather than
        # assuming 128 x 128, so any Rescale size works. np.float64 replaces
        # the np.float_ alias, which NumPy 2.0 removed.
        mask = mask.reshape(mask.shape[:2] + (1,)).astype(np.float64)
        mask = mask.transpose((2, 0, 1))
        return {'cat': torch.from_numpy(cat),
                'mask': torch.from_numpy(mask)}



def step(net, cats, masks, criteria, optimizer):
    """Run one optimisation step on a single sample and return its loss.

    ``cats`` gets a leading batch axis before the forward pass; the first
    element of the network output is compared against ``masks``.
    """
    batch = cats.unsqueeze(0).float()
    target = masks.float()

    prediction = net(batch)[0]
    loss = criteria(prediction, target)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss


def train(net, epochs, batch_size, criteria, optimizer, trainset, testset):
    """Train ``net`` sample by sample and score it on ``testset`` each epoch.

    Args:
        net: Model to optimise.
        epochs: Number of passes over ``trainset``.
        batch_size: Accepted for interface compatibility but unused; samples
            are fed to ``step`` one at a time.
        criteria: Loss callable passed to ``step``.
        optimizer: Optimiser passed to ``step``.
        trainset: Indexable dataset yielding ``{'cat', 'mask'}`` tensor dicts.
        testset: Dataset handed to ``performance_sorensen`` after each epoch.
    """
    num_iter = len(trainset)
    best = 0.0  # only consumed by the optional checkpoint snippet below
    for e in range(epochs):
        print(f'Beginning epoch {e+1}')
        net.train()
        epoch_loss = 0.0
        for i in range(num_iter):
            sample = trainset[i]
            cat = sample['cat']
            mask = sample['mask']
            if use_gpu:
                cat = cat.cuda()
                mask = mask.cuda()
            # Accumulate a plain float: summing the loss tensors themselves
            # keeps every iteration's autograd history alive for the epoch.
            epoch_loss += float(step(net, cat, mask, criteria, optimizer))
        print(f"Epoch #{e+1}")
        print(f"Total loss: {epoch_loss}")
        net.eval()
        test_p = performance_sorensen(net, testset)
        # if test_p > best:  # If you want to save the model
            # best = test_p
            # torch.save(net.state_dict(), f'q1/checkpoints_{e+1}.pt')


def calculate_sorensen(arr1, arr2):
    """Return the Sørensen–Dice coefficient of two binary masks.

    Computes ``2 * |A ∩ B| / (|A| + |B|)``, where non-zero entries count as
    foreground. The previous implementation returned the fraction of equal
    pixels (plain accuracy), which rewards agreeing on background and is not
    the Dice score this function is named after.

    Args:
        arr1: First mask (array-like of 0/1 values).
        arr2: Second mask, same shape as ``arr1``.

    Returns:
        A float in ``[0, 1]``; ``1.0`` when both masks are empty.
    """
    first = np.asarray(arr1).astype(bool)
    second = np.asarray(arr2).astype(bool)
    total = first.sum() + second.sum()
    if total == 0:
        # Two empty masks agree perfectly; avoid 0 / 0.
        return 1.0
    return 2.0 * np.logical_and(first, second).sum() / total


def performance_sorensen(net, dataset):
    """Print and return the mean ``calculate_sorensen`` score over a dataset.

    Each sample's photo is pushed through ``net``; the prediction is
    thresholded at 0.5 and compared against the sample's mask. The caller is
    responsible for putting ``net`` into eval mode.

    Args:
        net: Segmentation model returning a ``1 x 1 x H x W`` prediction.
        dataset: Indexable dataset yielding ``{'cat', 'mask'}`` tensor dicts.

    Returns:
        The average score, or ``0.0`` for an empty dataset.
    """
    count = len(dataset)
    if count == 0:
        # Nothing to score; avoid dividing by zero below.
        print('Average sorensen-dice coefficient: 0.0')
        return 0.0

    total = 0
    # Evaluation never back-propagates, so skip building autograd graphs.
    with torch.no_grad():
        for i in range(count):
            sample = dataset[i]
            cat = sample['cat']
            mask = sample['mask']
            if use_gpu:
                cat = cat.cuda()
            predicted_mask = net(cat.unsqueeze(0).float())
            img = predicted_mask.detach().cpu().numpy()[0, 0, :, :]
            img[img > 0.5] = 1
            img[img <= 0.5] = 0
            truth = mask.numpy()[0, :, :]
            total += calculate_sorensen(img, truth)

    average = total / count
    print(f'Average sorensen-dice coefficient: {average}')
    return average


class DiceLoss(torch.nn.Module):
    """Smoothed Dice loss: ``1 - (2 * overlap + 1) / (sum_p + sum_t + 1)``."""

    def __init__(self):
        super().__init__()

    def forward(self, predict, target):
        """Return the mean loss after summing over axis 1 twice per tensor."""
        def collapse(tensor):
            # Sum axis 1, then the axis that moves into position 1.
            return tensor.sum(dim=1).sum(dim=1)

        predict = predict.contiguous()
        target = target.contiguous()

        overlap = collapse(predict * target)
        numerator = 2. * overlap + 1.
        denominator = collapse(predict) + collapse(target) + 1.

        per_item = 1 - (numerator / denominator)
        return per_item.mean()


class BCELoss(torch.nn.Module):
    """Module wrapper around mean-reduced binary cross-entropy."""

    def __init__(self):
        super().__init__()

    def forward(self, predict, target):
        """Return the mean binary cross-entropy of ``predict`` vs ``target``."""
        loss = F.binary_cross_entropy(predict, target, reduction='mean')
        return loss


def main(do_train):
    """Build the datasets, model, optimiser and loss, then train.

    ``do_train`` is accepted but not consulted: training always runs.
    """
    scale = Rescale((128, 128))

    def build_pipeline():
        # Every dataset gets its own Compose/ToTensor pair.
        return transforms.Compose([scale, ToTensor()])

    trainset = CatDataset('cat_data/Train/', 'train.pkl', transform=build_pipeline())
    # Alternative training sets:
    # trainset = CatDataset('cat_data/AugmentedTrain/', 'aug_train.pkl', transform=build_pipeline())
    # trainset = CatDataset('transfer/Train', 'transfer_train.pkl', transform=build_pipeline())
    testset = CatDataset('cat_data/Test/', 'test.pkl', transform=build_pipeline())

    net = Net()
    if use_gpu:
        net = net.cuda()

    epochs = 50
    batch_size = 1

    # Swap in DiceLoss() here to train with the Dice objective instead.
    criteria = BCELoss()
    optimizer = torch.optim.SGD(net.parameters(), lr=0.01, momentum=0.99)

    train(net, epochs, batch_size, criteria, optimizer, trainset, testset)


# Script entry point: start training when this cell/file is run directly.
if __name__ == '__main__':
    main(True)

## 1.1 Implement U-NET

In [0]:
# nn.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

class Net(nn.Module):
    """U-Net style encoder/decoder that maps a 3-channel image to a mask.

    Four 2x max-pool steps shrink the input and four stride-2 transposed
    convolutions restore it, so input height and width should be multiples
    of 16 for the skip connections to line up. The output has one channel
    with values squashed into ``[0, 1]`` by a sigmoid.

    Attribute names and the layer order inside each block are kept stable so
    that previously saved ``state_dict`` checkpoints continue to load.
    """

    @staticmethod
    def _double_conv(input_size, output_size):
        """Return two (3x3 conv -> ReLU -> BatchNorm) stages in sequence."""
        return torch.nn.Sequential(
            torch.nn.Conv2d(input_size, output_size, 3, padding=1),
            torch.nn.ReLU(inplace=True),
            torch.nn.BatchNorm2d(output_size),
            torch.nn.Conv2d(output_size, output_size, 3, padding=1),
            torch.nn.ReLU(inplace=True),
            torch.nn.BatchNorm2d(output_size),
        )

    def downsample_block(self, input_size, output_size):
        """Build the convolution block used on the contracting path."""
        return self._double_conv(input_size, output_size)

    def upsample_block(self, input_size, output_size):
        """Build the convolution block used on the expanding path."""
        return self._double_conv(input_size, output_size)

    def __init__(self):
        super().__init__()

        # NOTE: (3, 64, 3) instead of (1, 64, 3) due to color images
        self.down_conv1 = self.downsample_block(3, 64)
        self.down_conv2 = self.downsample_block(64, 128)
        self.down_conv3 = self.downsample_block(128, 256)
        self.down_conv4 = self.downsample_block(256, 512)
        self.down_conv5 = self.downsample_block(512, 1024)

        # Each up-convolution halves the channels; concatenating the skip
        # tensor doubles them again, hence the 2x input width of up_conv*.
        self.up4 = nn.ConvTranspose2d(1024, 512, 2, stride=2)
        self.up_conv4 = self.upsample_block(1024, 512)
        self.up3 = nn.ConvTranspose2d(512, 256, 2, stride=2)
        self.up_conv3 = self.upsample_block(512, 256)
        self.up2 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.up_conv2 = self.upsample_block(256, 128)
        self.up1 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.up_conv1 = self.upsample_block(128, 64)

        self.final_conv = torch.nn.Conv2d(64, 1, 1)

    def forward(self, x):
        """Return the per-pixel foreground probability map for ``x``."""
        # Contracting path: keep each block's output for the skip links.
        skip1 = self.down_conv1(x)
        skip2 = self.down_conv2(F.max_pool2d(skip1, 2))
        skip3 = self.down_conv3(F.max_pool2d(skip2, 2))
        skip4 = self.down_conv4(F.max_pool2d(skip3, 2))
        bottom = self.down_conv5(F.max_pool2d(skip4, 2))

        # Expanding path: upsample, concatenate the skip, convolve.
        up = self.up_conv4(torch.cat((skip4, self.up4(bottom)), dim=1))
        up = self.up_conv3(torch.cat((skip3, self.up3(up)), dim=1))
        up = self.up_conv2(torch.cat((skip2, self.up2(up)), dim=1))
        up = self.up_conv1(torch.cat((skip1, self.up1(up)), dim=1))

        # torch.sigmoid replaces the deprecated F.sigmoid.
        return torch.sigmoid(self.final_conv(up))

## 1.2 Data Augmentation

In [4]:
# augment.py
import os
import numpy as np

from skimage import io, transform

# Destination and source folders for the augmentation pass. TRAIN is read
# through its `input/` and `mask/` sub-folders below; AUG is only referenced
# by the (currently disabled) save calls.
AUG = 'cat_data/AugmentedTrain/'
TRAIN = 'cat_data/Train'

# For every training photo, load it together with its mask (named
# `mask_<photo filename>`). All augmentations below are commented out, so as
# written the loop only reads the files; uncomment one section to write that
# variant of each photo/mask pair into AUG.
for f in os.listdir(TRAIN + '/input'):
    cat = io.imread(TRAIN + '/input/' + f)
    mask = io.imread(TRAIN + '/mask/mask_' + f)

    # Do a flip left-right
    # flip_cat = np.fliplr(cat)
    # flip_mask = np.fliplr(mask)
    
    # io.imsave(AUG + '/input/flip' + f, flip_cat)
    # io.imsave(AUG + '/mask/flip_mask_' + f, flip_mask)

    # Do an upside down flip
    # updown_cat = np.flipud(cat)
    # updown_mask = np.flipud(mask)
    # io.imsave(AUG + '/input/updown' + f, updown_cat)
    # io.imsave(AUG + '/mask/updown_mask_' + f, updown_mask)
    
    # Do a 90 degree rotation right
    # right_cat = transform.rotate(cat, 90, resize=True)
    # right_mask = transform.rotate(mask, 90, resize=True)
    # io.imsave(AUG + '/input/right' + f, right_cat)
    # io.imsave(AUG + '/mask/right_mask_' + f, right_mask)

    # Do a 90 degree rotation left
    # left_cat = transform.rotate(cat, 270, resize=True)
    # left_mask = transform.rotate(mask, 270, resize=True)
    # io.imsave(AUG + '/input/left' + f, left_cat)
    # io.imsave(AUG + '/mask/left_mask_' + f, left_mask)

    # Do a random crop (same margin removed from every side)
    # print(cat.shape)
    # crop_len = np.random.randint(0, max(min(cat.shape[0] - 150, cat.shape[1] - 150), 1)) // 4
    # crop_cat = cat[crop_len:cat.shape[0] - crop_len, crop_len:cat.shape[1]-crop_len,:]
    # crop_mask = mask[crop_len:cat.shape[0] - crop_len, crop_len:cat.shape[1]-crop_len]
    # io.imsave(AUG + '/input/crop' + f, crop_cat)
    # io.imsave(AUG + '/mask/crop_mask_' + f, crop_mask)

    # Do a horizontal stretch (width x 1.5)
    # stretch_cat = transform.resize(cat, (cat.shape[0], int(cat.shape[1]*1.5)))
    # stretch_mask = transform.resize(mask, (mask.shape[0], int(mask.shape[1]*1.5)))
    # io.imsave(AUG + '/input/hstretch' + f, stretch_cat)
    # io.imsave(AUG + '/mask/hstretch_mask_' + f, stretch_mask)

    # Do a vertical stretch (height x 1.5)
    # stretch_cat = transform.resize(cat, (int(cat.shape[0] * 1.5), cat.shape[1]))
    # stretch_mask = transform.resize(mask, (int(mask.shape[0] * 1.5), mask.shape[1]))
    # io.imsave(AUG + '/input/vstretch' + f, stretch_cat)
    # io.imsave(AUG + '/mask/vstretch_mask_' + f, stretch_mask)

FileNotFoundError: ignored

## 1.3 Transfer Learning

## 1.4 Visualizing segmentation predictions

In [0]:
import argparse
import torch
import pickle
import numpy as np
import matplotlib.pyplot as plt
import glob
import cv2

from nn import Net
from skimage import io, feature, transform


use_gpu = torch.cuda.is_available()


def binary_mask(arr):
    """Threshold ``arr`` in place: values above 0.5 become 1, the rest 0.

    Nothing is returned; the caller's array is modified.
    """
    above = arr > 0.5
    at_or_below = arr <= 0.5
    arr[at_or_below] = 0
    arr[above] = 1


def process_cat(cat, output_size=(128, 128)):
    """Resize an ``H x W x C`` image and return it as a ``C x H x W`` tensor.

    Args:
        cat: Image array with the channel axis last.
        output_size: ``(height, width)`` tuple used as-is, or an int that is
            matched to the shorter edge while keeping the aspect ratio.
            Defaults to the previously hard-coded ``(128, 128)``; exposing it
            makes the int branch reachable.

    Returns:
        A tensor created from the resized, channel-first array.
    """
    h, w = cat.shape[:2]
    if isinstance(output_size, int):
        if h > w:
            new_h, new_w = output_size * h / w, output_size
        else:
            new_h, new_w = output_size, output_size * w / h
    else:
        new_h, new_w = output_size

    sized_cat = transform.resize(cat, (int(new_h), int(new_w)))
    # numpy image is H x W x C, torch expects C x H x W
    return torch.from_numpy(sized_cat.transpose((2, 0, 1)))


def plot_seg(image, net):
    """Return the image at ``image`` with the predicted outline drawn on it.

    The network prediction is resized back to the photo's size, thresholded
    at 0.5, and its contours are drawn in green (3 px) on the loaded photo.

    Args:
        image: Path of the photo to segment.
        net: Segmentation model returning a ``1 x 1 x H x W`` prediction.
    """
    cat = io.imread(image)
    cat_tensor = process_cat(cat)

    if use_gpu:
        cat_tensor = cat_tensor.cuda()
    # Inference only: no autograd bookkeeping needed.
    with torch.no_grad():
        cat_mask = net(cat_tensor.unsqueeze(0).float())
    img = cat_mask.detach().cpu().numpy()[0, 0, :, :]
    resized_mask = transform.resize(img, cat.shape[0:2], anti_aliasing=False, preserve_range=True)
    binary_mask(resized_mask)
    resized_mask = (resized_mask * 255).astype(np.uint8)
    _, thresh = cv2.threshold(resized_mask, 127, 255, cv2.THRESH_BINARY)
    # OpenCV 3 returns (image, contours, hierarchy) while OpenCV 4 returns
    # (contours, hierarchy); the contours are second-to-last in both.
    contours = cv2.findContours(thresh, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)[-2]
    cv2.drawContours(cat, contours, -1, (0, 255, 0), 3)
    return cat


def plot_mask(image, net):
    """Return ``(photo, ground-truth mask, predicted binary mask)``.

    ``image`` is a pair: the photo path at index 0 and the ground-truth mask
    path at index 1. The prediction is resized to the photo's height and
    width and thresholded at 0.5.
    """
    photo_path, truth_path = image[0], image[1]

    cat = io.imread(photo_path)
    cat_tensor = process_cat(cat)
    truth = io.imread(truth_path)

    if use_gpu:
        cat_tensor = cat_tensor.cuda()

    batch = cat_tensor.unsqueeze(0).float()
    prediction = net(batch).detach().cpu().numpy()[0, 0, :, :]

    full_size = transform.resize(
        prediction, cat.shape[0:2], anti_aliasing=False, preserve_range=True)
    binary_mask(full_size)
    return cat, truth, full_size


def main(images, mask=False):
    """Visualise predictions for ``images``.

    Args:
        images: Sequence of ``(photo_path, mask_path)`` pairs.
        mask: When true, save one figure (``vis1_2_dice``) laying out photo,
            ground truth and prediction for every image, two images per row.
            Otherwise save ``seg<N>.png`` per image with the predicted
            outline drawn on the photo.
    """
    net = Net()
    if use_gpu:
        device = torch.device('cuda')
        # NOTE(review): the GPU and CPU paths load different checkpoints
        # (dice_20 vs bce_35) — confirm this is intentional.
        net.load_state_dict(torch.load('q1_checkpoints/q1_2_dice_20.pt'))
        net.to(device)
    else:
        net.load_state_dict(torch.load('q1_checkpoints/q1_2_bce_35.pt', map_location='cpu'))
    # A freshly built module is in training mode; switch to eval so that
    # BatchNorm uses the statistics stored in the checkpoint rather than
    # those of the single image being visualised.
    net.eval()

    if not mask:
        for number, image in enumerate(images, start=1):
            cat = plot_seg(image[0], net)
            plt.imsave(f'seg{number}.png', cat)
        return

    fig = plt.figure(figsize=(16, 20))
    rows = len(images) // 2 + 1
    titles = ('Cat', 'Groundtruth', 'Prediction')
    for idx, image in enumerate(images):
        cat, truth, pred = plot_mask(image, net)
        panels = ((cat, None), (truth, 'gray'), (pred, 'gray'))
        # Each image takes three consecutive cells of the 6-column grid, so
        # image idx starts at cell idx * 3 + 1. Deriving the cell from idx
        # also handles a single image, where the old code referenced an
        # undefined loop variable.
        for offset, ((picture, cmap), title) in enumerate(zip(panels, titles), start=1):
            ax = fig.add_subplot(rows, 6, idx * 3 + offset)
            ax.imshow(picture, cmap=cmap)
            if idx < 2:  # column headers on the first row only
                ax.set_title(title)
    plt.savefig('vis1_2_dice', bbox_inches='tight')


# Command-line entry point: read the (photo, mask) filename pairs from a
# pickle file, turn them into paths below `root`, and visualise them.
if __name__ == '__main__':
    import os

    parser = argparse.ArgumentParser(description='Visualize cat segmentations')
    parser.add_argument('root', type=str, help='path to where files are located')
    parser.add_argument('pickle', type=str, help='pickle file containing list of filenames')
    parser.add_argument('--mask', action="store_true", help='display mask instead of outline')

    args = parser.parse_args()
    # NOTE: unpickling can execute arbitrary code; only pass index files that
    # were produced locally. The context manager closes the handle, which the
    # previous bare open() never did.
    with open(args.pickle, 'rb') as index_file:
        files = pickle.load(index_file)

    # os.path.join gives the same result as the old string concatenation when
    # `root` ends with a separator and also works when it does not.
    new_files = [
        (os.path.join(args.root, 'input', entry[0]),
         os.path.join(args.root, 'mask', entry[1]))
        for entry in files
    ]
    main(new_files, args.mask)

# 2.Bounding Box Design

## 2.1 Problem definition

## 2.2 Implementation

In [0]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

class Net(nn.Module):
    """CNN that regresses a single scalar from a one-channel image.

    Four (conv -> batch-norm -> ReLU -> 2x max-pool) stages are followed by
    one linear layer. The linear layer expects a 12 x 12 x 128 feature map,
    which is what a 200 x 200 input yields (200 -> 100 -> 50 -> 25 -> 12).
    """

    @staticmethod
    def _conv_stage(in_channels, out_channels, kernel_size):
        """Return one conv/BN/ReLU/pool stage with size-preserving padding."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels,
                      kernel_size=kernel_size, padding=kernel_size // 2),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )

    def __init__(self):
        super().__init__()

        # Trainable parameters per stage: conv weights and biases plus the
        # batch-norm scale and shift (2 per channel).
        # (5 * 5 * 1 + 1) * 16 + 2 * 16 = 448
        self.layer1 = self._conv_stage(1, 16, 5)
        # (5 * 5 * 16 + 1) * 32 + 2 * 32 = 12896
        self.layer2 = self._conv_stage(16, 32, 5)
        # (3 * 3 * 32 + 1) * 64 + 2 * 64 = 18624
        self.layer3 = self._conv_stage(32, 64, 3)
        # (3 * 3 * 64 + 1) * 128 + 2 * 128 = 74112
        self.layer4 = self._conv_stage(64, 128, 3)

        # 12 * 12 * 128 weights + 1 bias = 18433
        self.fc = nn.Linear(12 * 12 * 128, 1)

    def forward(self, x):
        """Return a ``(batch, 1)`` tensor of predictions for ``x``."""
        features = x
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            features = stage(features)
        flat = features.view(features.size(0), -1)
        return self.fc(flat)

## 2.3 IOU Optimization

In [0]:
import gc

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
from shapely.geometry.point import Point
from skimage.color import gray2rgb
from skimage.draw import circle, circle_perimeter_aa
from torch import nn, optim
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

from nn_q2 import Net

use_gpu = torch.cuda.is_available()

class CircleDataset(Dataset):
    """Synthetic dataset of noisy circle images generated on demand.

    One random seed is drawn per index at construction time, so asking for
    the same index again regenerates the same image and parameters.
    """

    def __init__(self, size):
        self.size = size
        self.seeds = [np.random.randint(1, 2147483647) for _ in range(size)]

    def __len__(self):
        return self.size

    def __getitem__(self, index):
        """Return ``(image, (row, col, radius))`` for ``index``."""
        # Pin the global NumPy RNG to this item's seed while generating ...
        np.random.seed(self.seeds[index])
        circle_params, image = noisy_circle(200, 50, 2)
        # ... then reseed it unpredictably so later draws are not tied to
        # this item's seed.
        np.random.seed(None)
        return image, circle_params


def draw_circle(img, row, col, rad):
    """Draw an anti-aliased circle outline onto ``img`` in place.

    Perimeter pixels that fall outside the image are dropped; the rest are
    overwritten with the anti-aliasing intensities.
    """
    rows, cols, intensity = circle_perimeter_aa(row, col, rad)
    height, width = img.shape[0], img.shape[1]
    inside = (rows >= 0) & (rows < height) & (cols >= 0) & (cols < width)
    img[rows[inside], cols[inside]] = intensity[inside]


def noisy_circle(size, radius, noise):
    """Generate a square image containing one circle outline plus noise.

    Args:
        size: Side length of the image in pixels.
        radius: Exclusive upper bound for the random circle radius; radii
            are drawn from ``[10, radius)``. Values of 11 or less always
            give a radius of 10.
        noise: Amplitude of the uniform ``[0, noise)`` noise added per pixel.

    Returns:
        ``((row, col, rad), img)`` with ``img`` a float64 ``size x size``
        array.
    """
    # np.float was removed in NumPy 1.24; np.float64 is the same dtype.
    img = np.zeros((size, size), dtype=np.float64)

    # Circle
    row = np.random.randint(size)
    col = np.random.randint(size)
    # The upper bound must exceed the lower one: max(10, radius) made
    # randint(10, 10) raise for radius <= 10.
    rad = np.random.randint(10, max(11, radius))
    draw_circle(img, row, col, rad)

    # Noise
    img += noise * np.random.rand(*img.shape)
    return (row, col, rad), img


# Loaded (netx, nety, netr) triples keyed by device string, so repeated
# find_circle calls do not re-read the checkpoints from disk.
_CIRCLE_NET_CACHE = {}


def _load_circle_nets(device):
    """Return the row/col/radius networks on ``device``, loading them once."""
    key = str(device)
    if key not in _CIRCLE_NET_CACHE:
        nets = []
        for path in ('q2x.pt', 'q2y.pt', 'q2r.pt'):
            net = Net()
            net.load_state_dict(torch.load(path, map_location=device))
            net.to(device)
            # Inference must use the stored BatchNorm statistics.
            net.eval()
            nets.append(net)
        _CIRCLE_NET_CACHE[key] = tuple(nets)
    return _CIRCLE_NET_CACHE[key]


def find_circle(img):
    """Predict the circle in a 200 x 200 image.

    Args:
        img: 2-D array with 200 x 200 entries.

    Returns:
        ``(row, col, radius)`` as ints (each network output truncated).

    The checkpoints ``q2x.pt``, ``q2y.pt`` and ``q2r.pt`` are read from the
    working directory on first use and cached per device. Unlike before, the
    function also works without a GPU and runs the models in eval mode.
    """
    device = torch.device('cuda' if use_gpu else 'cpu')
    netx, nety, netr = _load_circle_nets(device)

    # Shape the image as a (batch=1, channel=1, 200, 200) float tensor.
    img_tensor = img.reshape(200, 200, 1).transpose((2, 0, 1))
    tensor = torch.from_numpy(img_tensor).unsqueeze(1).float().to(device)

    with torch.no_grad():
        # .item() extracts the single scalar; calling int() on a 1-element
        # array is deprecated in recent NumPy.
        resx = netx(tensor).item()
        resy = nety(tensor).item()
        resr = netr(tensor).item()
    return int(resx), int(resy), int(resr)


def iou(params0, params1):
    """Return intersection-over-union of two circles.

    Each argument unpacks to ``(row, col, radius)``; a circle is built by
    buffering its centre point by its radius.
    """
    discs = [
        Point(row, col).buffer(rad)
        for row, col, rad in (params0, params1)
    ]
    first, second = discs
    overlap_area = first.intersection(second).area
    combined_area = first.union(second).area
    return overlap_area / combined_area


def observe():
    '''
    Plot the groundtruth bounding circle and the predicted bounding circle.
    The groundtruth is in green and the prediction in red.

    Ten random noisy circles are generated; for each, the IoU between the
    prediction and the ground truth is printed and the annotated image shown.
    Checkpoints q2x.pt, q2y.pt and q2r.pt are read from the working directory.
    '''
    # Load the three regressors once (not once per image), onto whatever
    # device is available, and in eval mode so BatchNorm uses the statistics
    # stored in the checkpoints.
    device = torch.device('cuda' if use_gpu else 'cpu')
    nets = []
    for path in ('q2x.pt', 'q2y.pt', 'q2r.pt'):
        net = Net()
        net.load_state_dict(torch.load(path, map_location=device))
        net.to(device)
        net.eval()
        nets.append(net)

    def paint_perimeter(canvas, circle_params, colour):
        # Colour the in-bounds perimeter pixels of one circle.
        rr, cc, _ = circle_perimeter_aa(
            circle_params[0], circle_params[1], circle_params[2])
        valid = (
            (rr >= 0) &
            (rr < canvas.shape[0]) &
            (cc >= 0) &
            (cc < canvas.shape[1])
        )
        canvas[rr[valid], cc[valid]] = colour

    for _ in range(10):
        params, img = noisy_circle(200, 50, 2)
        img_tensor = img.reshape(200, 200, 1).transpose((2, 0, 1))
        tensor = torch.from_numpy(img_tensor).unsqueeze(1).float().to(device)
        with torch.no_grad():
            predicted = tuple(int(net(tensor).item()) for net in nets)

        # Scale into [0, 1] before the uint8 conversion: the noisy image
        # exceeds 1, so multiplying by 255 directly wrapped around.
        peak = img.max()
        scaled = img / peak if peak > 0 else img
        canvas = cv2.cvtColor((scaled * 255).astype(np.uint8), cv2.COLOR_GRAY2RGB)

        paint_perimeter(canvas, predicted, (255, 0, 0))  # prediction: red
        paint_perimeter(canvas, params, (0, 255, 0))     # ground truth: green

        print(iou(predicted, params))
        plt.imshow(canvas)
        plt.show()


def iou_loss(predict, target):
    """Return ``1 - mean IoU`` of rasterised predicted and target circles.

    Args:
        predict: Tensor whose rows index as ``(row, col, radius)``.
        target: Tensor of ground-truth circles in the same layout.

    Returns:
        A scalar tensor.

    NOTE: both inputs are detached and rasterised with NumPy, so the value
    is not connected to the network that produced ``predict``. Calling
    ``backward()`` on it succeeds (the rasters are created with
    ``requires_grad=True``) but no gradient reaches the model parameters.
    """
    predict = predict.detach().cpu().numpy()
    target = target.detach().cpu().numpy()
    batch = predict.shape[0]
    # np.float was removed in NumPy 1.24; np.float64 is the same dtype.
    predict_arr = np.zeros((batch, 1, 200, 200), dtype=np.float64)
    target_arr = np.zeros((batch, 1, 200, 200), dtype=np.float64)
    for i in range(batch):
        p = predict[i]
        t = target[i]
        rr1, cc1 = circle(p[0], p[1], p[2])
        rr2, cc2 = circle(t[0], t[1], t[2])
        # Keep only the pixels that fall inside the 200 x 200 canvas.
        valid1 = (
            (rr1 >= 0) &
            (rr1 < 200) &
            (cc1 >= 0) &
            (cc1 < 200)
        )
        valid2 = (
            (rr2 >= 0) &
            (rr2 < 200) &
            (cc2 >= 0) &
            (cc2 < 200)
        )
        predict_arr[i, 0, rr1[valid1], cc1[valid1]] = 1
        target_arr[i, 0, rr2[valid2], cc2[valid2]] = 1
    predict = torch.tensor(predict_arr, requires_grad=True)
    target = torch.tensor(target_arr, requires_grad=True)

    overlap = predict * target
    intersect = overlap.view(batch, 1, -1).sum(2)
    union = (predict + target - overlap).view(batch, 1, -1).sum(2)

    loss = intersect / union

    return 1 - loss.mean()


def train_model(models, trainloader, testloader, epochs, criterion, optimizer):
    """Train the row, column and radius regressors side by side.

    ``models`` and ``optimizer`` are indexed 0..2 in the order (row, col,
    radius); each network is fitted to its own column of the targets with
    ``criterion``. After every epoch the mean training loss is printed and
    ``eval_model`` is run on ``testloader``.
    """
    nets = (models[0], models[1], models[2])
    optimizers = (optimizer[0], optimizer[1], optimizer[2])

    for e in range(epochs):
        running_loss = 0
        for images, params in tqdm(trainloader):
            # eval_model leaves the nets in eval mode, so re-enable training.
            for net in nets:
                net.train()

            params = torch.stack(params, 1).float()
            if use_gpu:
                images = images.cuda()
                params = params.cuda()

            batch = images.unsqueeze(1).float()
            outputs = [net(batch) for net in nets]

            # For the IOU loss, replace the per-network losses with:
            # loss = criterion(torch.stack(outputs, dim=1), params)
            losses = [
                criterion(out[:, 0], params[:, column])
                for column, out in enumerate(outputs)
            ]

            for opt in optimizers:
                opt.zero_grad()
            for loss in losses:
                loss.backward()
            # loss.backward()  # For IOU loss
            for opt in optimizers:
                opt.step()

            running_loss += losses[0].item() + losses[1].item() + losses[2].item()
            gc.collect()

        print(f"Epoch {e+1}")
        print(f"Training loss: {running_loss / len(trainloader)}")
        eval_model(nets, testloader)
        # torch.save(nets[0].state_dict(), f'q2_checkpoints/q2_checkpoints_large2mx_{e+1}.pt')
        # torch.save(nets[1].state_dict(), f'q2_checkpoints/q2_checkpoints_large2my_{e+1}.pt')
        # torch.save(nets[2].state_dict(), f'q2_checkpoints/q2_checkpoints_large2mr_{e+1}.pt')


def eval_model(models, testloader):
    """Print the mean IoU of the three regressors over ``testloader``.

    Args:
        models: Sequence indexed 0..2 as (row, col, radius) networks. They
            are switched to eval mode and left there.
        testloader: Loader yielding ``(images, params)`` batches.

    Every sample of every batch is scored. The previous version scored only
    the first sample of each batch, which is equivalent for a batch size
    of 1.
    """
    nets = (models[0], models[1], models[2])
    for net in nets:
        net.eval()

    total_iou = 0.0
    sample_count = 0
    # Evaluation never back-propagates, so skip building autograd graphs.
    with torch.no_grad():
        for images, params in testloader:
            targets = torch.stack(params, 1).float().numpy()
            if use_gpu:
                images = images.cuda()

            batch = images.unsqueeze(1).float()
            # One 1-D array of per-sample predictions per network.
            predictions = [net(batch).cpu().numpy()[:, 0] for net in nets]
            for k in range(targets.shape[0]):
                # Plain floats rather than 1-element arrays for shapely.
                predicted = tuple(float(values[k]) for values in predictions)
                total_iou += iou(predicted, targets[k])
                sample_count += 1

    if sample_count == 0:
        # Avoid dividing by zero when the loader yields nothing.
        print("Average IoU for test set: no samples")
        return
    print(f"Average IoU for test set {total_iou / sample_count}")


def train():
    """Train the three circle regressors on synthetic data, then evaluate.

    Uses 200000 generated training samples (batches of 64, 4 workers) and
    1000 test samples (batch size 1), MSE loss, Adam with lr 0.01 and 20
    epochs.
    """
    trainloader = DataLoader(
        CircleDataset(200000), batch_size=64, shuffle=True, num_workers=4)
    testloader = DataLoader(CircleDataset(1000), batch_size=1, shuffle=True)

    # One network each for row, column and radius.
    nets = [Net() for _ in range(3)]
    if use_gpu:
        nets = [net.cuda() for net in nets]
    nets = tuple(nets)

    # Swap in iou_loss here to experiment with the IoU objective.
    criterion = nn.MSELoss()
    optimizers = tuple(optim.Adam(net.parameters(), lr=0.01) for net in nets)

    epochs = 20
    train_model(nets, trainloader, testloader, epochs, criterion, optimizers)
    eval_model(nets, testloader)


def main(do_train):
    """Train the models, or benchmark the saved ones on 1000 random circles.

    With ``do_train`` true, training runs and the process exits. Otherwise
    the mean IoU and the fraction of IoUs above 0.7 are printed.
    """
    if do_train:
        train()
        exit(0)

    def score_random_circle():
        # Generate one noisy circle and score the detector's guess.
        truth, image = noisy_circle(200, 50, 2)
        return iou(truth, find_circle(image))

    results = np.array([score_random_circle() for _ in range(1000)])
    print(results.mean())
    print((results > 0.7).mean())


# Script entry point: passing False skips training and runs the benchmark
# branch of main().
if __name__ == '__main__':
    main(False)

## 2.4 Visualization and error analysis

# 3.Hot Dog or Not Hot Dog