In [None]:
import warnings
# Globally silence every warning for the rest of the kernel session: the call
# passes no category/module/message filter, so nothing is let through.
# NOTE(review): this also hides deprecation warnings from the libraries used
# below — consider narrowing the filter once the notebook is stable.
warnings.filterwarnings('ignore')

In [None]:
import torch
import json
import pathlib
import numpy as np
import pytorch_lightning as pl
from torch import nn
from torchvision.transforms import transforms
import torch.nn.functional as F
import torchvision.models as models
from torch.utils.data import DataLoader, Dataset
from pytorch_lightning.loggers import WandbLogger
from sklearn.metrics import f1_score, precision_score, recall_score, accuracy_score
from PIL import Image
from torchmetrics import Accuracy

In [None]:
def calc_weights(counts, dampen = 2):
    """Turn per-class sample counts into normalised, dampened inverse-frequency weights.

    Every class first gets a weight proportional to ``1 / count`` (scaled so
    these raw weights sum to 1). Each raw weight is then raised to the power
    ``1 / dampen`` and the result is rescaled so the final weights sum to 1.

    Args:
        counts: Iterable of per-class sample counts. It is materialised into a
            list first, so one-shot iterators are handled correctly.
        dampen: Root applied to the raw weights. ``1`` keeps plain
            inverse-frequency weighting; larger values pull the weights
            towards a uniform distribution.

    Returns:
        A list of floats, one per entry of ``counts`` and in the same order.
        An empty ``counts`` yields an empty list.

    Raises:
        ZeroDivisionError: for plain Python numbers, when a count is zero, or
            when ``dampen`` is zero and ``counts`` is non-empty.
    """
    counts = list(counts)

    # Loop-invariant normaliser: the sum of inverse counts is the same for
    # every class, so compute it once rather than once per class.
    inverse_total = sum([1 / c for c in counts])

    weights = [(1 / (c * inverse_total)) ** (1 / dampen) for c in counts]

    # make weights add up to 1 (total hoisted out of the comprehension)
    total = sum(weights)
    return [w / total for w in weights]

In [None]:
def get_data(data_dir, mode):
    """Collect the labelled image records belonging to one split of the dataset.

    ``data_dir`` must contain exactly one ``*.json`` annotation file and an
    ``images`` sub-directory. The JSON is expected to hold an ``'images'``
    list (entries with ``'id'`` and ``'file_name'``) and an ``'annotations'``
    list (entries with ``'image_id'`` and ``'category_id'``).

    Only annotations whose category id is one of 0, 2, 4 or 5 are kept; those
    ids are remapped to the contiguous class indices 0..3. Records whose image
    file is missing on disk are dropped.

    The split is decided by the last character of the image's grandparent
    directory path, read as a digit: 0-7 belong to ``'train'``, 8-9 to any
    other ``mode``.
    # NOTE(review): presumably that digit is the tail of a location/sequence
    # folder name — confirm against the dataset layout.

    Args:
        data_dir: Dataset root (``pathlib.Path`` or path string).
        mode: ``'train'`` selects the training split; any other value selects
            the held-out split.

    Returns:
        A list of dicts with keys ``'file_name'`` (``pathlib.Path`` to the
        image) and ``'label'`` (class index in 0..3).

    Raises:
        AssertionError: if ``data_dir`` does not contain exactly one JSON file.
        ValueError: if the split character of an existing image is not a digit.
    """
    data_dir = pathlib.Path(data_dir)

    json_files = list(data_dir.glob("*.json"))
    assert len(json_files) == 1
    # Context manager so the annotation file handle is closed deterministically.
    with open(str(json_files[0]), 'rb') as annotation_file:
        annotations = json.load(annotation_file)

    # Build map of image_id to file_name.
    id_to_filename = {o['id']: o['file_name'] for o in annotations['images']}

    # Map the dataset's category ids onto contiguous pytorch class indices.
    category_id_to_index = {0: 0, 2: 1, 4: 2, 5: 3}

    # Build list of records, skipping categories that are not modelled.
    records = [
        {
            'file_name': data_dir / "images" / id_to_filename[o['image_id']],
            'label': category_id_to_index[o['category_id']]
        }
        for o in annotations['annotations']
        if o['category_id'] in category_id_to_index
    ]

    def _use(file_name, mode):
        # Last character of the grandparent directory path selects the split.
        value = str(file_name.parent.parent)[-1]
        if mode == 'train':
            return int(value) in [0, 1, 2, 3, 4, 5, 6, 7]
        else:
            return int(value) in [8, 9]

    # Validate the data exists. 'file_name' already includes data_dir/images,
    # so test it directly: re-joining it onto data_dir breaks relative roots.
    # The existence check stays first so _use only sees files that are present.
    return [
        o for o in records
        if o['file_name'].exists() and _use(o['file_name'], mode)
    ]

In [None]:
# Both splits are read from the same dataset root; get_data decides which
# records belong to which split based on the requested mode.
DATA_DIR = pathlib.Path("/home/harry/DATA/animal_data/channel-islands/")
data_train = get_data(DATA_DIR, 'train')
data_val = get_data(DATA_DIR, 'valid')
print(len(data_train), len(data_val))

In [None]:
class CustomImageDataset(Dataset):
    """Map-style dataset that loads images from disk and pairs them with labels.

    ``data`` is a sequence of dicts, each with a ``'file_name'`` entry (path
    of the image file) and a ``'label'`` entry (integer class index).
    Every image is converted to RGB, resized to 224x224, turned into a tensor
    and normalised per channel.
    """

    def __init__(self, data):
        """Store the records and build the preprocessing pipeline.

        Args:
            data: Sequence of ``{'file_name': ..., 'label': ...}`` dicts.
        """
        self.data = data

        # Transforms
        # NOTE(review): the mean/std values match the widely used ImageNet
        # channel statistics — presumably chosen to suit a pretrained backbone.
        self.transforms = transforms.Compose([
            transforms.Resize((224, 224),interpolation=Image.NEAREST),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225])
        ])

    def __len__(self):
        """Return the number of records."""
        return len(self.data)

    def __getitem__(self, idx):
        """Load, preprocess and return the ``idx``-th sample.

        Returns:
            A ``(image, label)`` tuple: the output of the transform pipeline
            and an int64 tensor of shape ``(1,)`` holding the class index.
        """
        record = self.data[idx]

        # Normalize above is configured for three channels, so force RGB:
        # grayscale, palette or RGBA files would otherwise produce a tensor
        # with a different channel count. The context manager closes the
        # underlying file once the pixel data has been converted.
        with Image.open(record['file_name']) as image:
            rgb_image = image.convert('RGB')

        transformed_image = self.transforms(rgb_image)
        # Shape (1,) is kept on purpose so batches collate to (batch, 1).
        label = torch.tensor([record['label']], dtype=torch.long)
        return transformed_image, label

In [None]:
class Model(pl.LightningModule):
    """4-class image classifier: torchvision ResNet-50 backbone plus a small MLP head.

    The backbone (everything except ResNet's final ``fc`` layer) and the head
    are both optimised, since ``configure_optimizers`` passes all parameters
    to Adam. The loss is a class-weighted cross entropy whose weights come
    from the label distribution of ``data_train``.

    The module also owns its data: ``train_dataloader`` / ``val_dataloader``
    wrap the record lists given to the constructor in ``CustomImageDataset``.
    """

    def __init__(self, data_train, data_val):
        """Build the network, the weighted loss and the validation buffers.

        Args:
            data_train: Training records (dicts with ``'file_name'`` and
                ``'label'``); also used to derive the loss class weights.
            data_val: Validation records in the same format.
        """
        super().__init__()

        self.data_train = data_train
        self.data_val = data_val

        backbone = models.resnet50(pretrained=True)
        num_filters = backbone.fc.in_features
        # Drop the final classification layer; keep the pooled feature extractor.
        layers = list(backbone.children())[:-1]
        self.feature_extractor = nn.Sequential(*layers)

        # Size the head from the backbone instead of hard-coding its width.
        _fc_layers = [
            nn.Linear(num_filters, 256),
            nn.ReLU(),
            nn.Linear(256, 32),
            nn.Linear(32, 4)
        ]
        self.fc = nn.Sequential(*_fc_layers)

        weights = torch.tensor(
            self._calc_loss_weights(self.data_train), dtype=torch.float32
        )
        self.loss_function = nn.CrossEntropyLoss(weight=weights)

        self.accuracy = Accuracy(average='weighted', num_classes=4)

        # Per-epoch buffers filled by validation_step and drained by
        # validation_epoch_end.
        self.validation_prediction_list = []
        self.validation_gt_list = []

    def _calc_loss_weights(self, data_train):
        """Count the labels in ``data_train`` and convert them to loss weights.

        Returns:
            A list of 4 floats produced by ``calc_weights(counts, dampen=2)``.
        """
        counts = [0, 0, 0, 0]
        for d in data_train:
            counts[int(d['label'])] += 1

        weights = calc_weights(counts, dampen=2)

        return weights

    def _calc_metrics(self, pred, gt):
        """Compute f1 / recall / precision / accuracy with scikit-learn.

        Args:
            pred: Tensor of per-class scores; the predicted class is the
                argmax over the last axis.
            gt: Tensor of ground-truth class indices, one per prediction.

        Returns:
            Dict mapping ``'f1'``, ``'recall'``, ``'precision'`` and
            ``'accuracy'`` to float32 scalar tensors on ``self.device``.
            The first three use ``average='weighted'``.
        """
        gt = gt.long().cpu().detach().numpy()
        pred = np.argmax(pred.cpu().detach().numpy(), axis=-1)

        kwargs = {'average': 'weighted'}

        f1 = f1_score(y_true=gt, y_pred=pred, **kwargs)
        recall = recall_score(y_true=gt, y_pred=pred, **kwargs)
        precision = precision_score(y_true=gt, y_pred=pred, **kwargs)
        accuracy = accuracy_score(y_true=gt, y_pred=pred)

        metrics = {
            'f1': torch.tensor(f1, dtype=torch.float32, device=self.device),
            'recall': torch.tensor(recall, dtype=torch.float32, device=self.device),
            'precision': torch.tensor(precision, dtype=torch.float32, device=self.device),
            'accuracy': torch.tensor(accuracy, dtype=torch.float32, device=self.device)
        }

        return metrics

    def forward(self, x):
        """Return per-class log-probabilities for a batch of images."""
        x = self.feature_extractor(x)
        # Drop the two trailing spatial axes only, so a batch of one keeps
        # its batch axis.
        x = x.squeeze(-1).squeeze(-1)
        x = self.fc(x)

        # NOTE: CrossEntropyLoss applies log_softmax itself; doing so on
        # already-normalised log-probabilities leaves them unchanged, so the
        # loss value is not affected by this extra normalisation.
        return F.log_softmax(x, dim=1)

    def training_step(self, batch, batch_idx):
        """Run one optimisation step and log the loss and batch metrics."""
        x, y = batch
        # Flatten (batch, 1) labels to (batch,). A bare squeeze() would also
        # drop the batch axis when the final batch holds a single sample.
        y = y.reshape(-1)

        preds = self(x)
        loss = self.loss_function(preds, y)
        metrics = self._calc_metrics(pred=preds, gt=y)

        self.log('loss', loss)
        for k, v in metrics.items():
            self.log(k, v)

        return loss

    def validation_step(self, batch, batch_idx):
        """Compute the validation loss and buffer predictions for epoch metrics."""
        x, y = batch
        y = y.reshape(-1)

        preds = self(x)
        loss = self.loss_function(preds, y)

        # Keep detached CPU copies: the buffers are only read by
        # _calc_metrics, so they should not pin device memory.
        self.validation_prediction_list.append(preds.detach().cpu())
        self.validation_gt_list.append(y.detach().cpu())

        self.log('val_loss', loss)

        return loss

    def validation_epoch_end(self, outputs):
        """Log metrics over the batches buffered this epoch, then reset the buffers."""
        if not self.validation_prediction_list:
            # Nothing was buffered; torch.cat would fail on an empty list.
            return

        prediction_logits = torch.cat(self.validation_prediction_list, dim=0)
        gt = torch.cat(self.validation_gt_list, dim=0)

        # Reset before computing metrics so the next epoch starts clean.
        # Without this, each epoch's metrics would include all previously
        # buffered batches and the lists would grow without bound.
        self.validation_prediction_list.clear()
        self.validation_gt_list.clear()

        metrics = self._calc_metrics(prediction_logits, gt)
        for k, v in metrics.items():
            self.log(f'val_{k}', v)

    def train_dataloader(self):
        """Return a shuffled loader over the training records."""
        train_ds = CustomImageDataset(self.data_train)
        return DataLoader(train_ds, batch_size=32, num_workers=4, shuffle=True)

    def val_dataloader(self):
        """Return an in-order loader over the validation records."""
        valid_ds = CustomImageDataset(self.data_val)
        return DataLoader(valid_ds, batch_size=32, num_workers=4)

    def configure_optimizers(self):
        """Optimise all parameters with Adam at a learning rate of 1e-3."""
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer

In [None]:
# Build the model from the two splits and train it on a single GPU,
# sending the logged metrics to the Weights & Biases project "thea".
model = Model(data_train, data_val)
wandb_logger = WandbLogger(project="thea")
trainer = pl.Trainer(
    gpus=1,
    logger=wandb_logger,
)
# To train without W&B logging, construct the trainer as pl.Trainer(gpus=1).
trainer.fit(model)