In [None]:
import ast
import os
from copy import deepcopy

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import cv2
from skimage import io
from tqdm.auto import tqdm
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
import torchvision
from torchvision import transforms as T
from torchvision.models import resnet50, vgg19
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor
from torchvision.models.detection import fasterrcnn_resnet50_fpn, MaskRCNN, FasterRCNN, retinanet_resnet50_fpn
from torchvision.models.detection.rpn import AnchorGenerator

In [None]:
train_path = '../input/tensorflow-great-barrier-reef/train_images/'
df_path = '../input/tensorflow-great-barrier-reef/train.csv'
# Keep columns 0, 2 and the last one: video id, frame number (both used to build
# image paths below) and the annotation string.
df = pd.read_csv(df_path).iloc[:, [0, 2, -1]]
test_df = pd.read_csv('../input/tensorflow-great-barrier-reef/test.csv').iloc[:, [0, 2]]
# Frames with at least one annotated starfish vs. frames with none ('[]').
star_df = df[df['annotations'] != '[]']
non_df = df[df['annotations'] == '[]']
batch_size = 4
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Seed the RNGs the notebook draws from (numpy: sampling, shuffling, flips;
# torch: weight initialisation) so a Restart & Run All is reproducible.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# First, we pretrain the ResNet-50 backbone for 5 epochs on a binary classification task (does the frame contain a starfish?) so that it learns better feature maps for detection.

In [None]:
class ClassiData(Dataset):
    """Class-balanced binary classification dataset: starfish frames (label 1) vs. empty frames (label 0).

    Every starfish frame is used, plus an equally sized random sample of empty
    frames. The sample is redrawn (and reshuffled) after the last index has been
    served, so consecutive sequential passes see different negatives.

    Args:
        train_path: directory prefix containing ``video_<id>/<frame>.jpg`` images.
        star_df: array of rows for frames with annotations; column 0 is the video id
            and column 1 the frame number (both used to build the image path).
        non_df: array with the same columns for frames without annotations.

    Raises:
        ValueError: if there are fewer empty frames than starfish frames, so the
            classes cannot be balanced by sampling without replacement.
    """
    def __init__(self, train_path, star_df, non_df):
        if len(non_df) < len(star_df):
            raise ValueError('need at least as many negative frames as positive frames to balance the classes')
        # Append the binary label as an extra last column.
        self.star = np.concatenate((star_df, np.ones((len(star_df), 1))), axis = 1)
        self.non = np.concatenate((non_df, np.zeros((len(non_df), 1))), axis = 1)
        self.random_choice()
        self.path = train_path
        # Normalise with the standard ImageNet channel statistics.
        self.transforms = T.Compose([T.ToPILImage(), T.ToTensor(), T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])

    def random_choice(self):
        """Draw a fresh balanced sample: all positives plus as many random negatives, shuffled together."""
        rand = np.random.choice(len(self.non), len(self.star), replace = False)
        df = np.concatenate((self.star, self.non[rand]), axis = 0)
        np.random.shuffle(df)
        self.ids = df[:, 0]
        self.frames = df[:, 1]
        self.labels = df[:, -1]

    def __len__(self):
        return 2 * len(self.star)

    def _maybe_resample(self, idx):
        """Redraw the negative sample once the last index (``len(self) - 1``) has been served.

        Resampling mutates this object, so it only takes effect when items are
        fetched in index order in the main process (DataLoader without shuffle
        and with ``num_workers=0``).
        """
        if idx == len(self) - 1:
            self.random_choice()

    def __getitem__(self, idx):
        im_path = f'{self.path}video_{self.ids[idx]}/{self.frames[idx]}.jpg'
        label = self.labels[idx]
        # imread gives (H, W, C); the transforms expect a (C, H, W) tensor.
        image = torch.from_numpy(io.imread(im_path)).permute((2, 0, 1))
        image = self.transforms(image)
        # Random horizontal flip augmentation (p = 0.5).
        if np.random.rand() < 0.5:
            image = image.flip(-1)
        self._maybe_resample(idx)
        return image, torch.tensor([label]).float()

In [None]:
# Release cached, unused GPU memory held by PyTorch's allocator (nothing to do on CPU-only machines).
if torch.cuda.is_available():
    torch.cuda.empty_cache()

In [None]:
# Start from the saved ResNet-50 weights.
resnet = torch.load('../input/resnet50/resnet50.pth')
backbone = resnet50()
backbone.load_state_dict(resnet)
# Pool to a 7x7 grid instead of 1x1 so the classification head still sees spatial layout;
# the first Linear layer therefore takes 7 * 7 * 2048 features.
backbone.avgpool = nn.AdaptiveAvgPool2d(7)
head_layers = [
    nn.Linear(7 * 7 * 2048, 1024), nn.ReLU(), nn.Dropout(0.1),
    nn.Linear(1024, 256), nn.ReLU(), nn.Dropout(0.1),
    nn.Linear(256, 1), nn.Sigmoid(),  # single probability: does the frame contain a starfish?
]
backbone.fc = nn.Sequential(*head_layers)

In [None]:
# ClassiData shuffles its own rows in random_choice; the DataLoader iterates them in order.
train_data = DataLoader(ClassiData(train_path, star_df.values, non_df.values), batch_size = batch_size)
optimizer = torch.optim.Adam(backbone.parameters(), lr = 1e-4)
criteria = nn.BCELoss()  # the head already ends in Sigmoid, so plain BCE on probabilities

In [None]:
def train_model_classi(model, train_data, criteria, optimizer, epochs, val_data = None, device = device, val_frac = 0.2):
    """Train a binary classifier, then restore the weights with the lowest validation loss.

    Args:
        model: module mapping an image batch to probabilities with the same shape as the targets.
        train_data: sized iterable of ``(images, targets)`` batches (e.g. a DataLoader).
        criteria: loss taking ``(probabilities, targets)``, e.g. ``nn.BCELoss()``.
        optimizer: optimizer over ``model``'s parameters.
        epochs: number of passes over ``train_data``.
        val_data: optional iterable of validation batches in the same format. When it is
            None, the last ``val_frac`` of the batches of every epoch are held out instead.
        device: device the model and batches are moved to.
        val_frac: fraction of ``train_data`` batches held out when ``val_data`` is None.

    Each epoch prints the mean per-batch loss and the accuracy (fraction of samples
    whose thresholded prediction, > 0.5, matches the target) for both splits. If no
    validation batch is ever seen, the final weights are kept.
    """
    best_val_loss = float('inf')
    best_state = None
    model.to(device)
    # len() of a DataLoader already counts batches; it must not be divided by the batch size again.
    n_batches = len(train_data)
    n_train_batches = n_batches if val_data is not None else n_batches - int(n_batches * val_frac)

    def _run_batch(images, targets, train):
        """Forward one batch (updating weights if ``train``); return (loss, n_correct, n_samples)."""
        images = images.to(device)
        targets = targets.to(device)
        # Eval mode for validation disables dropout and freezes batch-norm statistics.
        model.train(train)
        with torch.set_grad_enabled(train):
            pred = model(images)
            loss = criteria(pred, targets)
        if train:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        n_correct = ((pred.detach() > 0.5) == (targets > 0.5)).sum().item()
        return loss.item(), n_correct, targets.numel()

    def _summarize(results):
        """Mean batch loss and sample accuracy of a list of _run_batch results (NaN if empty)."""
        if not results:
            return float('nan'), float('nan')
        losses, corrects, sizes = zip(*results)
        return sum(losses) / len(losses), sum(corrects) / sum(sizes)

    for epoch in range(epochs):
        train_results, val_results = [], []
        for step, (images, targets) in enumerate(tqdm(train_data)):
            is_train = step < n_train_batches
            (train_results if is_train else val_results).append(_run_batch(images, targets, is_train))
        if val_data is not None:
            for images, targets in val_data:
                val_results.append(_run_batch(images, targets, False))

        tr_loss, acc = _summarize(train_results)
        val_loss, val_acc = _summarize(val_results)
        torch.cuda.empty_cache()
        if val_results and val_loss < best_val_loss:
            best_val_loss = val_loss
            best_state = deepcopy(model.state_dict())
        print(f'epoch: {epoch} loss: {tr_loss} acc: {acc} ;;;; val_loss: {val_loss} val_acc: {val_acc}')
    model.train()
    if best_state is not None:
        model.load_state_dict(best_state)

In [None]:
train_model_classi(backbone, train_data, criteria, optimizer, epochs = 5)
torch.cuda.empty_cache()
# From here on (detection stage) batches hold a single image.
batch_size = 1

# Now we implement the dataset for object detection. It returns each image together with its target dictionary, which holds the bounding boxes, box areas, `iscrowd` flags and class labels.

In [None]:
class StarfishData(Dataset):
    """Object-detection dataset of annotated frames in the torchvision detection target format.

    Each item is ``(image, target)``. ``image`` is the frame converted by
    ToPILImage -> ToTensor as a (C, H, W) tensor. ``target`` has ``boxes`` (N, 4) as
    x1, y1, x2, y2 pixels, ``area`` (N,), ``iscrowd`` (N,) all zeros and ``labels``
    (N,) all ones (single starfish class). Frames and boxes are flipped
    horizontally together with probability 0.5.

    Args:
        train_path: directory prefix containing ``video_<id>/<frame>.jpg`` images.
        df: array whose column 0 is the video id, column 1 the frame number and last
            column the annotation string: a Python-literal list of dicts with
            ``x``, ``y``, ``width`` and ``height`` keys. Rows are shuffled once here.
    """
    def __init__(self, train_path, df):
        # Shuffle a copy so the caller's array is not reordered in place.
        rows = np.array(df, copy = True)
        np.random.shuffle(rows)
        self.ids = rows[:, 0]
        self.frames = rows[:, 1]
        self.annotations = rows[:, -1]
        self.path = train_path
        self.transforms = T.Compose([T.ToPILImage(), T.ToTensor()])

    def __len__(self):
        return len(self.ids)

    @staticmethod
    def _parse_annotations(annotation_str):
        """Parse an annotation string into float32 ``boxes`` (N, 4) x1y1x2y2 and ``areas`` (N,)."""
        # literal_eval only parses literals; eval would execute arbitrary code from the CSV.
        annot = ast.literal_eval(annotation_str)
        boxes = torch.tensor(
            [[a['x'], a['y'], a['x'] + a['width'], a['y'] + a['height']] for a in annot],
            dtype = torch.float32,
        ).reshape(-1, 4)  # keeps shape (0, 4) for frames without boxes
        areas = torch.tensor([a['width'] * a['height'] for a in annot], dtype = torch.float32)
        return boxes, areas

    @staticmethod
    def _hflip_boxes(boxes, width):
        """Mirror x1y1x2y2 boxes horizontally inside an image ``width`` pixels wide."""
        flipped = boxes.clone()
        # New x1 = W - old x2, new x2 = W - old x1; y coordinates are unchanged.
        flipped[:, [0, 2]] = width - boxes[:, [2, 0]]
        return flipped

    def __getitem__(self, idx):
        im_path = f'{self.path}video_{self.ids[idx]}/{self.frames[idx]}.jpg'
        boxes, areas = self._parse_annotations(self.annotations[idx])
        n_boxes = len(areas)
        # imread gives (H, W, C); the transforms expect a (C, H, W) tensor.
        image = torch.from_numpy(io.imread(im_path)).permute((2, 0, 1))
        image = self.transforms(image)
        if np.random.rand() < 0.5:
            image = image.flip(-1)  # flips the last (width) axis
            boxes = self._hflip_boxes(boxes, image.size()[-1])
        target = {
            'boxes': boxes,
            'area': areas,
            'iscrowd': torch.zeros(n_boxes, dtype = torch.int64),
            'labels': torch.ones(n_boxes, dtype = torch.int64),
        }
        return image, target

In [None]:
# Hold out a random fifth of the starfish frames (by row index) for validation.
files = star_df.values
n_files = len(files)
val_ind = np.random.choice(n_files, n_files // 5, replace = False)
train_ind = np.setdiff1d(np.arange(n_files), val_ind)  # sorted remaining indices
train_files, val_files = files[train_ind], files[val_ind]

In [None]:
def collate_detection(batch):
    """Turn a list of (image, target) items into an (images, targets) pair of tuples."""
    return tuple(zip(*batch))

# The training loader holds *all* starfish frames (train + val). StarfishData shuffles
# its rows once at construction and train_model holds out the trailing batches of each
# epoch, so that hold-out is a fixed random subset -- not necessarily `val_files`.
train_data = StarfishData(train_path, np.concatenate((train_files, val_files), axis = 0))
train_data = DataLoader(train_data, batch_size = batch_size, collate_fn = collate_detection)
val_data = StarfishData(train_path, val_files)
# Same collate_fn as training, so batches have the (images, targets) format train_model expects
# (the default collate would stack images and merge the target dicts instead).
val_data = DataLoader(val_data, collate_fn = collate_detection)

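# Here we load the pretrained weights for the backbone. The commented-out cell below is an alternative VGG-19 backbone; the cell after it copies the classification-pretrained ResNet-50 weights into the detector.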

In [None]:
# vgg = torch.load('../input/vgg19/vgg19.pth')
# backbone = torchvision.models.vgg19()
# backbone.load_state_dict(vgg)
# backbone = backbone.features
# backbone.out_channels = 512
# anchor_generator = AnchorGenerator(sizes = ((64,128,256,512),) , aspect_ratios = ((0.5 , 1 , 2),))
# roi_pooler = torchvision.ops.MultiScaleRoIAlign(featmap_names = ['0'] , output_size = 7 , sampling_ratio = 2)
# model = FasterRCNN(backbone , rpn_anchor_generator = anchor_generator , box_roi_pool = roi_pooler , num_classes = 2)

In [None]:
# Copy the classification-pretrained ResNet-50 weights (from `backbone`) into the
# detector's backbone body. Only keys that also exist in the body are copied; the
# rest of the body keeps its fresh initialisation.
resnet = backbone.state_dict()
model = fasterrcnn_resnet50_fpn(pretrained_backbone = False, num_classes = 2)
# model.transform = torchvision.models.detection.transform.GeneralizedRCNNTransform(min_size = 100, max_size = 1300, fixed_size = (360, 640), image_mean = [0.485, 0.456, 0.406],image_std = [0.229, 0.224, 0.225])
state = model.backbone.body.state_dict()
state.update({k: v for k, v in resnet.items() if k in state})
model.backbone.body.load_state_dict(state)

<All keys matched successfully>

# Finally, it is time to train the model. For each batch we read the images and targets, compute the losses, sum them, back-propagate, and update the weights.

In [None]:
def train_model(model, train_data, optimizer, epochs, val_data = None, device = device, val_frac = 0.2):
    """Train a torchvision detection model, then restore the weights with the lowest validation loss.

    Args:
        model: detection model that returns a dict of named losses when called as
            ``model(images, targets)`` in train mode.
        train_data: sized iterable of ``(images, targets)`` batches, where ``images`` is a
            sequence of image tensors and ``targets`` a sequence of target dicts.
        optimizer: optimizer over ``model``'s parameters.
        epochs: number of passes over ``train_data``.
        val_data: optional iterable of validation batches in the same format. When it is
            None, the last ``val_frac`` of the batches of every epoch are held out instead.
        device: device the model and batches are moved to.
        val_frac: fraction of ``train_data`` batches held out when ``val_data`` is None.

    Each epoch prints the mean per-batch value of every loss term as ``name:value``
    pairs joined by ``||``, train and validation separated by ``;;;;``. If no
    validation batch is ever seen, the final weights are kept.
    """
    best_val_loss = float('inf')
    best_state = None
    model.to(device)
    # Derive the split from the loader itself rather than from the notebook globals
    # train_ind / val_ind / batch_size.
    n_batches = len(train_data)
    n_train_batches = n_batches if val_data is not None else n_batches - int(n_batches * val_frac)

    def _run_batch(images, targets, train):
        """Forward one batch (stepping the optimizer if ``train``); return (loss names, CPU loss vector)."""
        images = [image.to(device) for image in images]
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
        with torch.set_grad_enabled(train):
            loss_dict = model(images, targets)
        if train:
            loss = sum(loss_dict.values())
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        return list(loss_dict.keys()), torch.stack(list(loss_dict.values())).detach().cpu()

    def _format(keys, losses):
        """Render mean loss terms as 'name:value||name:value' ('' if there were no batches)."""
        if not losses:
            return ''
        mean = torch.stack(losses).mean(dim = 0)
        return '||'.join(f'{k}:{v.item()}' for k, v in zip(keys, mean))

    # torchvision detection models only return losses in train mode, so validation
    # batches also run in train mode -- just without gradient tracking.
    model.train()
    for epoch in range(epochs):
        keys = []
        train_losses, val_losses = [], []
        for step, (images, targets) in enumerate(tqdm(train_data)):
            is_train = step < n_train_batches
            keys, losses = _run_batch(images, targets, is_train)
            (train_losses if is_train else val_losses).append(losses)
        if val_data is not None:
            for images, targets in val_data:
                keys, losses = _run_batch(images, targets, False)
                val_losses.append(losses)

        torch.cuda.empty_cache()
        if val_losses:
            total_val_loss = torch.stack(val_losses).mean(dim = 0).sum().item()
            if total_val_loss < best_val_loss:
                best_val_loss = total_val_loss
                best_state = deepcopy(model.state_dict())
        print(f'{_format(keys, train_losses)};;;;{_format(keys, val_losses)}')
    if best_state is not None:
        model.load_state_dict(best_state)

In [None]:
optimizer = torch.optim.Adam(params = model.parameters(), lr = 1e-4)  # fresh optimizer for the detector

In [None]:
train_model(model, train_data, optimizer, epochs = 20)

  0%|          | 0/4919 [00:00<?, ?it/s]

loss_classifier:0.0926336869597435||loss_box_reg:0.01926822029054165||loss_objectness:0.5660269856452942||loss_rpn_box_reg:4.120506286621094;;;;loss_classifier:0.027312489226460457||loss_box_reg:0.006774798035621643||loss_objectness:0.5218324661254883||loss_rpn_box_reg:3.6710381507873535


  0%|          | 0/4919 [00:00<?, ?it/s]

loss_classifier:0.03353343531489372||loss_box_reg:0.014418660663068295||loss_objectness:0.5505380034446716||loss_rpn_box_reg:3.6624109745025635;;;;loss_classifier:0.05336063355207443||loss_box_reg:0.012831361964344978||loss_objectness:0.5959644913673401||loss_rpn_box_reg:3.7681069374084473


  0%|          | 0/4919 [00:00<?, ?it/s]

loss_classifier:0.03306502476334572||loss_box_reg:0.01581144891679287||loss_objectness:0.5388669371604919||loss_rpn_box_reg:3.0779714584350586;;;;loss_classifier:0.03641708567738533||loss_box_reg:0.022692520171403885||loss_objectness:0.510316789150238||loss_rpn_box_reg:2.4562265872955322


  0%|          | 0/4919 [00:00<?, ?it/s]

loss_classifier:0.03716249763965607||loss_box_reg:0.021008087322115898||loss_objectness:0.5227206349372864||loss_rpn_box_reg:2.716264247894287;;;;loss_classifier:0.032086439430713654||loss_box_reg:0.023462621495127678||loss_objectness:0.49201616644859314||loss_rpn_box_reg:2.4563796520233154


  0%|          | 0/4919 [00:00<?, ?it/s]

loss_classifier:0.03503391146659851||loss_box_reg:0.022677088156342506||loss_objectness:0.5097997784614563||loss_rpn_box_reg:2.6003096103668213;;;;loss_classifier:0.035743121057748795||loss_box_reg:0.02410336397588253||loss_objectness:0.5194811820983887||loss_rpn_box_reg:2.711249828338623


  0%|          | 0/4919 [00:00<?, ?it/s]

loss_classifier:0.038966622203588486||loss_box_reg:0.02608516439795494||loss_objectness:0.5212851762771606||loss_rpn_box_reg:2.5547213554382324;;;;loss_classifier:0.03614703193306923||loss_box_reg:0.029005736112594604||loss_objectness:0.5187128782272339||loss_rpn_box_reg:2.583094596862793


  0%|          | 0/4919 [00:00<?, ?it/s]

loss_classifier:0.03587023541331291||loss_box_reg:0.029167892411351204||loss_objectness:0.5025084018707275||loss_rpn_box_reg:2.3822107315063477;;;;loss_classifier:0.033442284911870956||loss_box_reg:0.029568595811724663||loss_objectness:0.5175173878669739||loss_rpn_box_reg:2.6116812229156494


  0%|          | 0/4919 [00:00<?, ?it/s]

loss_classifier:0.03839766979217529||loss_box_reg:0.0338815301656723||loss_objectness:0.49472498893737793||loss_rpn_box_reg:2.302732229232788;;;;loss_classifier:0.03684091940522194||loss_box_reg:0.036520205438137054||loss_objectness:0.4913114011287689||loss_rpn_box_reg:2.3034512996673584


  0%|          | 0/4919 [00:00<?, ?it/s]

loss_classifier:0.04054247960448265||loss_box_reg:0.03499697893857956||loss_objectness:0.5152475833892822||loss_rpn_box_reg:2.473705530166626;;;;loss_classifier:0.03264222294092178||loss_box_reg:0.031113430857658386||loss_objectness:0.5082297921180725||loss_rpn_box_reg:2.6827752590179443


  0%|          | 0/4919 [00:00<?, ?it/s]

loss_classifier:0.04442007839679718||loss_box_reg:0.03866293653845787||loss_objectness:0.5009245872497559||loss_rpn_box_reg:2.366549015045166;;;;loss_classifier:0.031394876539707184||loss_box_reg:0.028515568003058434||loss_objectness:0.523177444934845||loss_rpn_box_reg:2.6895787715911865


  0%|          | 0/4919 [00:00<?, ?it/s]

loss_classifier:0.04819599539041519||loss_box_reg:0.04195871949195862||loss_objectness:0.4912000298500061||loss_rpn_box_reg:2.2549612522125244;;;;loss_classifier:0.03462108224630356||loss_box_reg:0.035490937530994415||loss_objectness:0.4714031219482422||loss_rpn_box_reg:2.127817392349243


  0%|          | 0/4919 [00:00<?, ?it/s]

loss_classifier:0.05084967240691185||loss_box_reg:0.04539933428168297||loss_objectness:0.49549755454063416||loss_rpn_box_reg:2.2549617290496826;;;;loss_classifier:0.042075030505657196||loss_box_reg:0.04318547621369362||loss_objectness:0.48796743154525757||loss_rpn_box_reg:2.411322593688965


  0%|          | 0/4919 [00:00<?, ?it/s]

loss_classifier:0.054211266338825226||loss_box_reg:0.05147261172533035||loss_objectness:0.4708094298839569||loss_rpn_box_reg:2.1552798748016357;;;;loss_classifier:0.03651880472898483||loss_box_reg:0.04394332692027092||loss_objectness:0.5319865345954895||loss_rpn_box_reg:2.5058372020721436


  0%|          | 0/4919 [00:00<?, ?it/s]

loss_classifier:0.058042991906404495||loss_box_reg:0.05474770441651344||loss_objectness:0.47629666328430176||loss_rpn_box_reg:2.150578022003174;;;;loss_classifier:0.04464937746524811||loss_box_reg:0.05095789581537247||loss_objectness:0.4899756610393524||loss_rpn_box_reg:2.1604809761047363


  0%|          | 0/4919 [00:00<?, ?it/s]

loss_classifier:0.06055605039000511||loss_box_reg:0.050765857100486755||loss_objectness:0.5314667224884033||loss_rpn_box_reg:2.3206276893615723;;;;loss_classifier:0.06890127062797546||loss_box_reg:0.06470700353384018||loss_objectness:0.4707297086715698||loss_rpn_box_reg:2.1211154460906982


  0%|          | 0/4919 [00:00<?, ?it/s]

loss_classifier:0.06550464034080505||loss_box_reg:0.062001243233680725||loss_objectness:0.4759283661842346||loss_rpn_box_reg:2.1334495544433594;;;;loss_classifier:0.04507775604724884||loss_box_reg:0.057751696556806564||loss_objectness:0.44580450654029846||loss_rpn_box_reg:1.9554444551467896


  0%|          | 0/4919 [00:00<?, ?it/s]

loss_classifier:0.061220645904541016||loss_box_reg:0.05205690860748291||loss_objectness:0.48818740248680115||loss_rpn_box_reg:2.2653932571411133;;;;loss_classifier:0.05229996517300606||loss_box_reg:0.0417497381567955||loss_objectness:0.48570695519447327||loss_rpn_box_reg:2.434040069580078


  0%|          | 0/4919 [00:00<?, ?it/s]

loss_classifier:0.05689128860831261||loss_box_reg:0.056882020086050034||loss_objectness:0.4784905016422272||loss_rpn_box_reg:2.1741628646850586;;;;loss_classifier:0.09095174819231033||loss_box_reg:0.07436792552471161||loss_objectness:0.44322827458381653||loss_rpn_box_reg:1.9288464784622192


  0%|          | 0/4919 [00:00<?, ?it/s]

loss_classifier:0.05351675674319267||loss_box_reg:0.055700644850730896||loss_objectness:0.4779527485370636||loss_rpn_box_reg:2.130586862564087;;;;loss_classifier:0.04211590439081192||loss_box_reg:0.05344346538186073||loss_objectness:0.44754552841186523||loss_rpn_box_reg:1.994939923286438


  0%|          | 0/4919 [00:00<?, ?it/s]

loss_classifier:0.06297570466995239||loss_box_reg:0.059919629245996475||loss_objectness:0.4694163203239441||loss_rpn_box_reg:2.147547960281372;;;;loss_classifier:0.05837774649262428||loss_box_reg:0.05128508433699608||loss_objectness:0.43657904863357544||loss_rpn_box_reg:2.0517847537994385


# The model is now ready for inference; the following cells show how to run it on the test set and submit predictions.

In [None]:
class StarfishTest(Dataset):
    """Inference-only dataset yielding one transformed (C, H, W) image tensor per row.

    Args:
        train_path: directory prefix containing ``video_<id>/<frame>.jpg`` images.
        df: array whose column 0 is the video id and column 1 the frame number.
    """
    def __init__(self, train_path, df):
        self.path = train_path
        self.ids, self.frames = df[:, 0], df[:, 1]
        self.transforms = T.Compose([T.ToPILImage(), T.ToTensor()])

    def __len__(self):
        return len(self.ids)

    def __getitem__(self, idx):
        video, frame = self.ids[idx], self.frames[idx]
        # imread gives (H, W, C); reorder to (C, H, W) before the transforms.
        chw = torch.from_numpy(io.imread(f'{self.path}video_{video}/{frame}.jpg')).permute((2, 0, 1))
        return self.transforms(chw)

In [None]:
# Not used by the submission loop below, which receives frames from the competition API.
test_data = StarfishTest(train_path = train_path, df = test_df.values)

In [None]:
def get_pred(pred, threshold = 0.5):
    """Format one torchvision detection output as a competition annotation string.

    Args:
        pred: dict with ``scores`` (N,), ``boxes`` (N, 4) as x1, y1, x2, y2 and
            ``labels`` (N,) tensors.
        threshold: detections must score strictly above this to be kept.

    Returns:
        Space-separated ``"score x y width height"`` groups (score to 4 decimals,
        coordinates rounded to whole pixels) for every class-1 detection above
        ``threshold``; an empty string when there are none.
    """
    # .tolist() yields plain Python numbers; formatting tensor elements directly
    # would write e.g. "tensor(0.9000, device='cuda:0')" into the submission.
    scores = pred['scores'].tolist()
    boxes = pred['boxes'].tolist()
    labels = pred['labels'].tolist()
    parts = [
        f'{s:.4f} {x1:.0f} {y1:.0f} {x2 - x1:.0f} {y2 - y1:.0f}'
        for s, (x1, y1, x2, y2), l in zip(scores, boxes, labels)
        if s > threshold and l == 1
    ]
    return ' '.join(parts)

In [None]:
# Competition submission API; imported here rather than in the imports cell,
# presumably because it is only available inside the Kaggle competition environment.
import greatbarrierreef
model.to(device)
model.eval()
transform = T.Compose([T.ToPILImage(), T.ToTensor()])
env = greatbarrierreef.make_env()
iter_test = env.iter_test()
# Inference only: disable autograd so no computation graph is built for each frame.
with torch.no_grad():
    for (pixel_array, sample_prediction_df) in iter_test:
        # Assumes pixel_array is (H, W, C) like io.imread output -- reordered to (C, H, W) here.
        im = transform(torch.from_numpy(pixel_array).permute((2, 0, 1))).to(device)
        sample_prediction_df['annotations'] = get_pred(model([im])[0])
        env.predict(sample_prediction_df)