# multi-task_model-train

`multi-task_model-train.ipynb`

End of August attempts to create good model training workflows for multi-task experiments

Author: Jacob A Rose  
Created on: Monday August 29th, 2021

In [1]:
%load_ext autoreload
%autoreload 2

# Imports & definitions


In [2]:
from IPython.core.interactiveshell import InteractiveShell
# pretty print all cell's output and not just the last one
InteractiveShell.ast_node_interactivity = "all"
import os

if 'TOY_DATA_DIR' not in os.environ: 
    os.environ['TOY_DATA_DIR'] = "/media/data_cifs/projects/prj_fossils/data/toy_data"
default_root_dir = os.environ['TOY_DATA_DIR']

os.environ["CUDA_VISIBLE_DEVICES"] = "0"


import pandas as pd
pd.set_option('display.max_rows', 500)
pd.set_option('display.max_columns', 500)
pd.set_option('display.max_colwidth', 200)

import torch
from torch import nn
import torchvision
from torchvision import transforms
# from lightning_hydra_classifiers.data.utils import make_catalogs
# import torchdata

# import albumentations as A
import pytorch_lightning as pl
import timm
from rich import print as pp

import matplotlib.pyplot as plt
import pandas as pd
from munch import Munch

from lightning_hydra_classifiers.data.utils.make_catalogs import *

from lightning_hydra_classifiers.utils.metric_utils import get_per_class_metrics, get_scalar_metrics
from lightning_hydra_classifiers.utils.logging_utils import get_wandb_logger
import wandb

torch.manual_seed(17)

<torch._C.Generator at 0x7fb6311eac90>

## Datasets & DataModules

In [3]:
from lightning_hydra_classifiers.experiments.transfer_experiment import TransferExperiment




class PlantDataModule(pl.LightningDataModule):
#     valid_tasks = (0, 1)
    
    def __init__(self, 
                 batch_size,
                 task_id: int=0,
                 image_size: int=224,
                 image_buffer_size: int=32,
                 num_workers: int=4,
                 pin_memory: bool=True):
        super().__init__()
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.pin_memory = pin_memory
        
        
        self.experiment = TransferExperiment()
        self.set_task(task_id)        
        
        self.image_size = image_size
        self.image_buffer_size = image_buffer_size
        self.mean = [0.485, 0.456, 0.406]
        self.std = [0.229, 0.224, 0.225]
        # Train augmentation policy
        
        self.__init_transforms()
                
        self.tasks = self.experiment.get_multitask_datasets(train_transform=self.train_transform,
                                                            val_transform=self.val_transform)


    def __init_transforms(self):
        
        self.train_transform = transforms.Compose([
            transforms.RandomResizedCrop(size=self.image_size,
                                         scale=(0.25, 1.2),
                                         ratio=(0.7, 1.3),
                                         interpolation=2),
            torchvision.transforms.ToTensor(),
            transforms.RandomHorizontalFlip(),
            transforms.Normalize(self.mean, self.std),
            transforms.Grayscale(num_output_channels=3)
        ])

        self.val_transform = transforms.Compose([
            transforms.Resize(self.image_size+self.image_buffer_size),
            torchvision.transforms.ToTensor(),
            transforms.CenterCrop(self.image_size),
            transforms.Normalize(self.mean, self.std),
            transforms.Grayscale(num_output_channels=3)            
        ])

    def set_task(self, task_id: int):
        assert task_id in self.experiment.valid_tasks
        self.task_id = task_id
        
        
        
    @property
    def current_task(self):
        return self.tasks[self.task_id]

    def setup(self, stage=None):
        task = self.current_task
        # Assign train/val datasets for use in dataloaders
        if stage == 'fit' or stage is None:
            self.train_dataset = task['train']
            self.val_dataset = task['val']
            
            self.classes = self.train_dataset.classes
            self.num_classes = len(self.train_dataset.label_encoder)
            
        elif stage == 'test':
            self.test_dataset = task['test']
                        
    def train_dataloader(self):
        return torch.utils.data.DataLoader(self.train_dataset,
                          batch_size=self.batch_size,
                          pin_memory=self.pin_memory,
                          shuffle=True,
                          num_workers=self.num_workers,
                          drop_last=True)

    def val_dataloader(self):
        return torch.utils.data.DataLoader(self.val_dataset,
                          batch_size=self.batch_size,
                          pin_memory=self.pin_memory,
                          num_workers=self.num_workers)
    
    def test_dataloader(self):
        return torch.utils.data.DataLoader(self.test_dataset,
                          batch_size=self.batch_size,
                          pin_memory=self.pin_memory,
                          num_workers=self.num_workers)

## Model & LightningModules

In [4]:
class CustomResNet(nn.Module):
    def __init__(self,
                 num_classes: int,
                 model_name='resnet18',
                 pretrained=False):
        super().__init__()
        self.num_classes = num_classes
        self.model = timm.create_model(model_name, pretrained=pretrained)
        self.in_features = self.model.get_classifier().in_features
        self.model.fc = nn.Linear(self.in_features, self.num_classes)

    def forward(self, x):
        x = self.model(x)
        return x

class LitMultiTaskModule(pl.LightningModule):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.lr = config['lr']
        self.num_classes = config['num_classes']
#         self.save_hyperparameters()
        self._init_model(config)
        self.metrics = self._init_metrics(stage='all')
#         self.metric = pl.metrics.F1(num_classes=CONFIG['num_classes'])
        self.criterion = nn.CrossEntropyLoss()
        

    def forward(self, x, *args, **kwargs):
        return self.model(x)

    def configure_optimizers(self):
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr)
        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(self.optimizer, T_max=self.config['t_max'], eta_min=self.config['min_lr'])

        return {'optimizer': self.optimizer, 'lr_scheduler': self.scheduler}

    def training_step(self, batch, batch_idx):
        image = batch[0]
        target = batch[1]
        output = self.model(image)
        loss = self.criterion(output, target)
#         scores = self.metrics_train(output.argmax(1), target)
        scores = self.metrics_train(output, target)
        self.log_dict({"train_loss": loss, 'lr': self.optimizer.param_groups[0]['lr']},
                      on_step=True, on_epoch=True, prog_bar=True, logger=True)
        self.log_dict(scores,
                      on_step=False, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def validation_step(self, batch, batch_idx):
        image = batch[0]
        target = batch[1]
        output = self.model(image)
        loss = self.criterion(output, target)
#         scores = self.metrics_val(output.argmax(1), target)
        scores = self.metrics_val(output, target)
        
        self.log("val_loss", loss,
                  on_step=True, on_epoch=True, prog_bar=True, logger=True)
        self.log_dict(scores,
                      on_step=False, on_epoch=True, prog_bar=True, logger=True)
        
        return loss
    
    
    def _init_model(self, config):
        self.model =  CustomResNet(config["num_classes"],
                                   model_name=config["model_name"],
                                   pretrained=config["pretrained"])
    
    def _init_metrics(self, stage: str='train'):
        
        if stage in ['train', 'all']:
            self.metrics_train = get_scalar_metrics(num_classes=self.num_classes, average='macro', prefix='train')
#             self.metrics_train_per_class = get_per_class_metrics(num_classes=self.num_classes, prefix='train')
            
        if stage in ['val', 'all']:
            self.metrics_val = get_scalar_metrics(num_classes=self.num_classes, average='macro', prefix='val')
#             self.metrics_val_per_class = get_per_class_metrics(num_classes=self.num_classes, prefix='val')
            
        if stage in ['test', 'all']:
            self.metrics_test = get_scalar_metrics(num_classes=self.num_classes, average='macro', prefix='test')
#             self.metrics_test_per_class = get_per_class_metrics(num_classes=self.num_classes, prefix='test')

    

# Define & Run Experiment

## Config

In [5]:
#     "model":
#         {"backbone":{
#                  "name":'resnet50',
#                  "pretrained":True},
config = Munch({
    "seed":42,
    "model_name":'resnet50',
    "pretrained":True,
    "image_size": 224,
    "image_buffer_size": 32, 
    "num_classes": None,
    "lr": 5e-4,
    "min_lr": 1e-6,
    "t_max": 20,
    "num_epochs": 10,
    "batch_size": 32,
#     accum = 1,
    "precision": 16,
    "device": torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
    "num_workers": 4,
    "pin_memory": True
})


# Seed everything
pl.seed_everything(config['seed'])

Global seed set to 42


42

## DataModule

In [6]:
datamodule = PlantDataModule(batch_size=config.batch_size,
                             task_id=0,
                             image_size=config.image_size,
                             image_buffer_size=config.image_buffer_size,
                             num_workers=config.num_workers,
                             pin_memory=config.pin_memory)

datamodule.setup("fit")
config.num_classes = datamodule.num_classes

pp(config)
model = LitMultiTaskModule(config)

## Callbacks

In [7]:
# Checkpoint
checkpoint_callback = pl.callbacks.ModelCheckpoint(monitor='val_loss',
                                      save_top_k=1,
                                      save_last=True,
                                      save_weights_only=True,
                                      filename='checkpoint/{epoch:02d}-{val_loss:.4f}-{val_f1:.4f}',
                                      verbose=True,
                                      mode='min')
earlystopping = pl.callbacks.EarlyStopping(monitor='val_loss', patience=3, mode='min')

## Logger

In [8]:
wandb_logger = pl.loggers.WandbLogger(entity = "jrose",
                           project = "image_classification_train",
                           job_type = "train_supervised",
                           config=config,
                           group='ResNet')

## Trainer

In [9]:
# Initialize a trainer
trainer = pl.Trainer(
            limit_train_batches=0.1,
            limit_val_batches=0.1,
            max_epochs=config['num_epochs'],
            gpus=1,
#             accumulate_grad_batches=CONFIG['accum'],
            precision=config['precision'],
            callbacks=[earlystopping,
                       checkpoint_callback],
#                        ImagePredictionLogger(val_samples)],
#             checkpoint_callback=checkpoint_callback,
            logger=wandb_logger,
            weights_summary='top')

# datamodule.train_dataset[0]
type(datamodule.val_dataset)

Using native 16bit precision.
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs


lightning_hydra_classifiers.data.utils.make_catalogs.CSVDataset

In [None]:
from torchinfo import summary

# model_stats = summary(your_model, (1, 3, 28, 28), verbose=0)

# TRAIN

In [10]:
# Train the model ‚ö°üöÖ‚ö°
trainer.fit(model, datamodule)

# Close wandb run
wandb.finish() 

# datamodule.num_classes
# sorted(datamodule.classes)

  rank_zero_deprecation(
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [5]
[34m[1mwandb[0m: Currently logged in as: [33mjrose[0m (use `wandb login --relogin` to force relogin)
[34m[1mwandb[0m: wandb version 0.12.1 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade



  | Name          | Type             | Params
---------------------------------------------------
0 | model         | CustomResNet     | 23.7 M
1 | metrics_train | MetricCollection | 0     
2 | metrics_val   | MetricCollection | 0     
3 | metrics_test  | MetricCollection | 0     
4 | criterion     | CrossEntropyLoss | 0     
---------------------------------------------------
23.7 M    Trainable params
0         Non-trainable params
23.7 M    Total params
94.786    Total estimated model params size (MB)


                                                                                                                                                                                             

Global seed set to 42


Epoch 0:  80%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñé                | 51/64 [00:33<00:08,  1.54it/s, loss=3.8, v_num=xp39, train_loss_step=3.930, lr_step=0.0005]
Validating: 0it [00:00, ?it/s][A
Validating:   0%|                                                                                                                                                     | 0/13 [00:00<?, ?it/s][A
Epoch 0:  83%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñâ              | 53/64 [00:35<00:07,  1.51it/s, loss=3.8, v_num=xp39, train_loss_step=3.930, lr_step=0.0005][A
Validating:  15%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñã                                       

Epoch 0, global step 50: val_loss reached 4.17365 (best 4.17365), saving model to "/media/data/jacob/GitHub/lightning-hydra-classifiers/notebooks/image_classification_train/2xf9xp39/checkpoints/checkpoint/epoch=00-val_loss=4.1736-val_f1=0.0000.ckpt" as top 1


Epoch 1:  81%|‚ñä| 52/64 [00:33<00:07,  1.58it/s, loss=3.61, v_num=xp39, train_loss_step=3.180, lr_step=0.000497, val_loss_step=4.910, val_loss_epoch=4.170, val/F1_top1=0.0162, val/acc_top1=0
Validating: 0it [00:00, ?it/s][A
Validating:   0%|                                                                                                                                                     | 0/13 [00:00<?, ?it/s][A
Epoch 1:  84%|‚ñä| 54/64 [00:35<00:06,  1.56it/s, loss=3.61, v_num=xp39, train_loss_step=3.180, lr_step=0.000497, val_loss_step=4.910, val_loss_epoch=4.170, val/F1_top1=0.0162, val/acc_top1=0[A
Epoch 1:  88%|‚ñâ| 56/64 [00:35<00:04,  1.61it/s, loss=3.61, v_num=xp39, train_loss_step=3.180, lr_step=0.000497, val_loss_step=4.910, val_loss_epoch=4.170, val/F1_top1=0.0162, val/acc_top1=0[A
Validating:  38%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñè         

Epoch 1, global step 101: val_loss reached 3.75095 (best 3.75095), saving model to "/media/data/jacob/GitHub/lightning-hydra-classifiers/notebooks/image_classification_train/2xf9xp39/checkpoints/checkpoint/epoch=01-val_loss=3.7510-val_f1=0.0000.ckpt" as top 1


Epoch 2:  81%|‚ñä| 52/64 [00:30<00:06,  1.72it/s, loss=3.36, v_num=xp39, train_loss_step=3.130, lr_step=0.000488, val_loss_step=3.920, val_loss_epoch=3.750, val/F1_top1=0.0221, val/acc_top1=0
Validating: 0it [00:00, ?it/s][A
Validating:   0%|                                                                                                                                                     | 0/13 [00:00<?, ?it/s][A
Validating:   8%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñä                                                                                                                                  | 1/13 [00:01<00:23,  1.98s/it][A
Epoch 2:  84%|‚ñä| 54/64 [00:33<00:06,  1.66it/s, loss=3.36, v_num=xp39, train_loss_step=3.130, lr_step=0.000488, val_loss_step=3.920, val_loss_epoch=3.750, val/F1_top1=0.0221, val/acc_top1=0[A
Validating:  23%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñå                                                       

Epoch 2, global step 152: val_loss reached 3.41573 (best 3.41573), saving model to "/media/data/jacob/GitHub/lightning-hydra-classifiers/notebooks/image_classification_train/2xf9xp39/checkpoints/checkpoint/epoch=02-val_loss=3.4157-val_f1=0.0000.ckpt" as top 1


Epoch 3:  81%|‚ñä| 52/64 [00:30<00:06,  1.74it/s, loss=3.23, v_num=xp39, train_loss_step=3.310, lr_step=0.000473, val_loss_step=3.130, val_loss_epoch=3.420, val/F1_top1=0.0238, val/acc_top1=0
Validating: 0it [00:00, ?it/s][A
Validating:   0%|                                                                                                                                                     | 0/13 [00:00<?, ?it/s][A
Validating:   8%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñä                                                                                                                                  | 1/13 [00:02<00:28,  2.39s/it][A
Epoch 3:  84%|‚ñä| 54/64 [00:33<00:06,  1.65it/s, loss=3.23, v_num=xp39, train_loss_step=3.310, lr_step=0.000473, val_loss_step=3.130, val_loss_epoch=3.420, val/F1_top1=0.0238, val/acc_top1=0[A
Validating:  23%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñå                                                       

Epoch 3, global step 203: val_loss was not in top 1


Epoch 4:  81%|‚ñä| 52/64 [00:31<00:07,  1.71it/s, loss=3.31, v_num=xp39, train_loss_step=3.120, lr_step=0.000452, val_loss_step=3.350, val_loss_epoch=3.620, val/F1_top1=0.0205, val/acc_top1=0
Validating: 0it [00:00, ?it/s][A
Validating:   0%|                                                                                                                                                     | 0/13 [00:00<?, ?it/s][A
Validating:   8%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñä                                                                                                                                  | 1/13 [00:02<00:27,  2.30s/it][A
Epoch 4:  84%|‚ñä| 54/64 [00:33<00:06,  1.63it/s, loss=3.31, v_num=xp39, train_loss_step=3.120, lr_step=0.000452, val_loss_step=3.350, val_loss_epoch=3.620, val/F1_top1=0.0205, val/acc_top1=0[A
Validating:  23%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñå                                                       

Epoch 4, global step 254: val_loss reached 3.30863 (best 3.30863), saving model to "/media/data/jacob/GitHub/lightning-hydra-classifiers/notebooks/image_classification_train/2xf9xp39/checkpoints/checkpoint/epoch=04-val_loss=3.3086-val_f1=0.0000.ckpt" as top 1


Epoch 5:  81%|‚ñä| 52/64 [00:30<00:06,  1.73it/s, loss=3.24, v_num=xp39, train_loss_step=3.160, lr_step=0.000427, val_loss_step=3.160, val_loss_epoch=3.310, val/F1_top1=0.0346, val/acc_top1=0
Validating: 0it [00:00, ?it/s][A
Validating:   0%|                                                                                                                                                     | 0/13 [00:00<?, ?it/s][A
Validating:   8%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñä                                                                                                                                  | 1/13 [00:02<00:25,  2.13s/it][A
Epoch 5:  84%|‚ñä| 54/64 [00:33<00:06,  1.66it/s, loss=3.24, v_num=xp39, train_loss_step=3.160, lr_step=0.000427, val_loss_step=3.160, val_loss_epoch=3.310, val/F1_top1=0.0346, val/acc_top1=0[A
Validating:  23%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñå                                                       

Epoch 5, global step 305: val_loss reached 3.13088 (best 3.13088), saving model to "/media/data/jacob/GitHub/lightning-hydra-classifiers/notebooks/image_classification_train/2xf9xp39/checkpoints/checkpoint/epoch=05-val_loss=3.1309-val_f1=0.0000.ckpt" as top 1


Epoch 6:  81%|‚ñä| 52/64 [00:31<00:07,  1.67it/s, loss=2.99, v_num=xp39, train_loss_step=2.980, lr_step=0.000397, val_loss_step=3.280, val_loss_epoch=3.130, val/F1_top1=0.0356, val/acc_top1=0
Validating: 0it [00:00, ?it/s][A
Validating:   0%|                                                                                                                                                     | 0/13 [00:00<?, ?it/s][A
Validating:   8%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñä                                                                                                                                  | 1/13 [00:01<00:22,  1.91s/it][A
Epoch 6:  84%|‚ñä| 54/64 [00:33<00:06,  1.62it/s, loss=2.99, v_num=xp39, train_loss_step=2.980, lr_step=0.000397, val_loss_step=3.280, val_loss_epoch=3.130, val/F1_top1=0.0356, val/acc_top1=0[A
Validating:  23%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñå                                                       

Epoch 6, global step 356: val_loss was not in top 1


Epoch 7:  81%|‚ñä| 52/64 [00:32<00:07,  1.66it/s, loss=2.99, v_num=xp39, train_loss_step=2.820, lr_step=0.000364, val_loss_step=2.670, val_loss_epoch=3.200, val/F1_top1=0.0323, val/acc_top1=0
Validating: 0it [00:00, ?it/s][A
Validating:   0%|                                                                                                                                                     | 0/13 [00:00<?, ?it/s][A
Validating:   8%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñä                                                                                                                                  | 1/13 [00:01<00:23,  1.94s/it][A
Epoch 7:  84%|‚ñä| 54/64 [00:34<00:06,  1.61it/s, loss=2.99, v_num=xp39, train_loss_step=2.820, lr_step=0.000364, val_loss_step=2.670, val_loss_epoch=3.200, val/F1_top1=0.0323, val/acc_top1=0[A
Validating:  23%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñå                                                       

Epoch 7, global step 407: val_loss reached 2.96658 (best 2.96658), saving model to "/media/data/jacob/GitHub/lightning-hydra-classifiers/notebooks/image_classification_train/2xf9xp39/checkpoints/checkpoint/epoch=07-val_loss=2.9666-val_f1=0.0000.ckpt" as top 1


Epoch 8:  81%|‚ñä| 52/64 [00:33<00:07,  1.60it/s, loss=2.76, v_num=xp39, train_loss_step=3.190, lr_step=0.000328, val_loss_step=2.640, val_loss_epoch=2.970, val/F1_top1=0.050, val/acc_top1=0.
Validating: 0it [00:00, ?it/s][A
Validating:   0%|                                                                                                                                                     | 0/13 [00:00<?, ?it/s][A
Epoch 8:  84%|‚ñä| 54/64 [00:35<00:06,  1.57it/s, loss=2.76, v_num=xp39, train_loss_step=3.190, lr_step=0.000328, val_loss_step=2.640, val_loss_epoch=2.970, val/F1_top1=0.050, val/acc_top1=0.[A
Epoch 8:  88%|‚ñâ| 56/64 [00:35<00:04,  1.62it/s, loss=2.76, v_num=xp39, train_loss_step=3.190, lr_step=0.000328, val_loss_step=2.640, val_loss_epoch=2.970, val/F1_top1=0.050, val/acc_top1=0.[A
Epoch 8:  91%|‚ñâ| 58/64 [00:36<00:03,  1.62it/s, loss=2.76, v_num=xp39, train_loss_step=3.190, lr_step=0.000328, val_loss_step=2.640, val_loss_epoch=2.970, val/F1_top1=0.050, val/acc_top1=0.

Epoch 8, global step 458: val_loss was not in top 1


Epoch 9:  81%|‚ñä| 52/64 [00:29<00:06,  1.79it/s, loss=2.71, v_num=xp39, train_loss_step=2.610, lr_step=0.00029, val_loss_step=2.650, val_loss_epoch=2.980, val/F1_top1=0.0508, val/acc_top1=0.
Validating: 0it [00:00, ?it/s][A
Validating:   0%|                                                                                                                                                     | 0/13 [00:00<?, ?it/s][A
Validating:   8%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñä                                                                                                                                  | 1/13 [00:02<00:28,  2.35s/it][A
Epoch 9:  84%|‚ñä| 54/64 [00:32<00:05,  1.70it/s, loss=2.71, v_num=xp39, train_loss_step=2.610, lr_step=0.00029, val_loss_step=2.650, val_loss_epoch=2.980, val/F1_top1=0.0508, val/acc_top1=0.[A
Validating:  23%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñå                                                       

Epoch 9, global step 509: val_loss was not in top 1


Epoch 9: 100%|‚ñà| 64/64 [00:39<00:00,  1.66it/s, loss=2.71, v_num=xp39, train_loss_step=2.610, lr_step=0.00029, val_loss_step=2.800, val_loss_epoch=3.030, val/F1_top1=0.038, val/acc_top1=0.0

Saving latest checkpoint...





  rank_zero_deprecation(


VBox(children=(Label(value=' 4.35MB of 4.35MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)‚Ä¶

0,1
train_loss_step,2.04558
lr_step,0.00029
epoch,9.0
trainer/global_step,509.0
_runtime,414.0
_timestamp,1630375943.0
_step,159.0
val_loss_step,2.80243
val_loss_epoch,3.02703
val/F1_top1,0.03804


0,1
train_loss_step,‚ñà‚ñá‚ñÜ‚ñÑ‚ñÜ‚ñÜ‚ñÉ‚ñÖ‚ñÉ‚ñÅ
lr_step,‚ñà‚ñà‚ñà‚ñá‚ñÜ‚ñÜ‚ñÖ‚ñÉ‚ñÇ‚ñÅ
epoch,‚ñÅ‚ñÅ‚ñÅ‚ñÇ‚ñÇ‚ñÇ‚ñÉ‚ñÉ‚ñÉ‚ñÉ‚ñÉ‚ñÉ‚ñÑ‚ñÑ‚ñÑ‚ñÖ‚ñÖ‚ñÖ‚ñÜ‚ñÜ‚ñÜ‚ñÜ‚ñÜ‚ñÜ‚ñá‚ñá‚ñá‚ñà‚ñà‚ñà
trainer/global_step,‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÇ‚ñÇ‚ñÇ‚ñÇ‚ñÇ‚ñÇ‚ñÇ‚ñÇ‚ñÇ‚ñÇ‚ñÇ‚ñÇ‚ñÇ‚ñÇ‚ñÇ‚ñÇ‚ñÇ‚ñÇ‚ñÇ‚ñÇ‚ñÇ‚ñÇ‚ñÇ‚ñÉ‚ñÉ‚ñÉ‚ñÉ‚ñÉ‚ñà
_runtime,‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÇ‚ñÇ‚ñÇ‚ñÇ‚ñÉ‚ñÉ‚ñÉ‚ñÉ‚ñÉ‚ñÉ‚ñÉ‚ñÉ‚ñÑ‚ñÑ‚ñÑ‚ñÑ‚ñÖ‚ñÖ‚ñÖ‚ñÖ‚ñÜ‚ñÜ‚ñÜ‚ñÜ‚ñÜ‚ñÜ‚ñÜ‚ñÜ‚ñá‚ñá‚ñá‚ñá‚ñà‚ñà‚ñà‚ñà
_timestamp,‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÇ‚ñÇ‚ñÇ‚ñÇ‚ñÉ‚ñÉ‚ñÉ‚ñÉ‚ñÉ‚ñÉ‚ñÉ‚ñÉ‚ñÑ‚ñÑ‚ñÑ‚ñÑ‚ñÖ‚ñÖ‚ñÖ‚ñÖ‚ñÜ‚ñÜ‚ñÜ‚ñÜ‚ñÜ‚ñÜ‚ñÜ‚ñÜ‚ñá‚ñá‚ñá‚ñá‚ñà‚ñà‚ñà‚ñà
_step,‚ñÅ‚ñÅ‚ñÅ‚ñÇ‚ñÇ‚ñÇ‚ñÇ‚ñÇ‚ñÇ‚ñÉ‚ñÉ‚ñÉ‚ñÉ‚ñÉ‚ñÉ‚ñÑ‚ñÑ‚ñÑ‚ñÑ‚ñÑ‚ñÖ‚ñÖ‚ñÖ‚ñÖ‚ñÖ‚ñÖ‚ñÜ‚ñÜ‚ñÜ‚ñÜ‚ñÜ‚ñÜ‚ñá‚ñá‚ñá‚ñá‚ñá‚ñà‚ñà‚ñà
val_loss_step,‚ñÑ‚ñÖ‚ñá‚ñà‚ñÑ‚ñÖ‚ñÖ‚ñÜ‚ñÉ‚ñÑ‚ñÑ‚ñÉ‚ñÉ‚ñÖ‚ñÖ‚ñÖ‚ñÉ‚ñÑ‚ñÑ‚ñÑ‚ñÇ‚ñÉ‚ñÉ‚ñÉ‚ñÉ‚ñÑ‚ñÑ‚ñÑ‚ñÅ‚ñÉ‚ñÉ‚ñÑ‚ñÇ‚ñÖ‚ñÉ‚ñÉ‚ñÉ‚ñÉ‚ñÉ‚ñÇ
val_loss_epoch,‚ñà‚ñÜ‚ñÑ‚ñÖ‚ñÉ‚ñÇ‚ñÇ‚ñÅ‚ñÅ‚ñÅ
val/F1_top1,‚ñÅ‚ñÇ‚ñÉ‚ñÇ‚ñÖ‚ñÖ‚ñÑ‚ñà‚ñà‚ñÖ


# Debugging

In [9]:
from lightning_hydra_classifiers.utils.common_utils import LabelEncoder
encoder = LabelEncoder()
experiment = TransferExperiment()

task_0 = experiment.setup_task_0()

task_0['train'].label_encoder

In [10]:
task_0['val'].label_encoder

<LabelEncoder>:
    num_classes: 93
    fit on num_samples: 16605

In [11]:
task_0['test'].label_encoder

<LabelEncoder>:
    num_classes: 20
    fit on num_samples: 2797

In [48]:
replace_class_indices = {"Nothofagaceae":"Fagaceae"}
task_0_label_encoder = task_0['train'].label_encoder
print(task_0_label_encoder)

task_0_label_encoder.__init__(replace = replace_class_indices)

print(task_0_label_encoder)

<LabelEncoder(num_classes=0)>
<num_replaced_classes=1>
<LabelEncoder(num_classes=0)>
<num_replaced_classes=1>


In [35]:
print(len(task_0_label_encoder.classes))
print(task_0_label_encoder)
task_0_label_encoder.fit(task_0['test'].targets)
print(len(task_0_label_encoder.classes))
print(task_0_label_encoder)
task_0_label_encoder.fit(task_0['train'].targets)
print(len(task_0_label_encoder.classes))
print(task_0_label_encoder)

In [58]:
print(task_0_label_encoder)



import collections

self = task_0_label_encoder
y = task_0['test'].targets

counts = collections.Counter(y)
print(self.num_samples)
self.num_samples += sum(counts.values())
print(self.num_samples)

classes = sorted(counts.keys())
print(classes)
print(len(classes))

In [53]:
# old_num_classes = len(self)
# print(f"old_num_classes={old_num_classes}")
# new_classes = sorted([label for label in classes if label not in self.classes])
# print(f'new_classes={new_classes}')
# print(f'len(new_classes)={len(new_classes)}')

old_num_classes = len(self)
print(f"old_num_classes={old_num_classes}")
new_classes = []
for label in classes:
    if (label not in self.classes) and (label not in self.replace):
        new_classes.append(label)
# new_classes = sorted([label for label in classes if label not in self.classes])
print(f'new_classes={new_classes}')
print(f'len(new_classes)={len(new_classes)}')

In [61]:
for i, label in enumerate(new_classes):
    self.class2idx[label] = old_num_classes + i
print(self.class2idx)
print(len(self.class2idx))

self.index2class = {v: k for k, v in self.class2idx.items()}

all_classes = []
for label in self.class2idx.keys():
    if label not in self.replace.keys():
        all_classes.append(label)
print(f"all_classes={all_classes}")
print(f"len(all_classes)={len(all_classes)}")

self.classes = sorted(all_classes)

In [64]:
# self.classes = [k for k in self.class2idx.keys() if k not in self.replace.keys()] 
print(len(self.class2idx), len(self.index2class), len(self.classes))

self.replace_class2idx_items()
print(len(self.class2idx), len(self.index2class), len(self.classes))

new_classes = [c for c in new_classes if c not in self.replace.keys()]
if len(new_classes):
    log.debug(f"[FITTING] {len(y)} samples with {len(classes)} classes, adding {len(new_classes)} new class labels. Latest num_classes = {len(self)}")
assert len(self) == (old_num_classes + len(new_classes))

In [None]:
self.index2class = {v: k for k, v in self.class2idx.items()}

# all_classes = []

self.classes = [k for k in self.class2idx.keys() if k not in self.replace.keys()]        
self.replace_class2idx_items()

new_classes = [c for c in new_classes if c not in self.replace.keys()]
if len(new_classes):
    log.debug(f"[FITTING] {len(y)} samples with {len(classes)} classes, adding {len(new_classes)} new class labels. Latest num_classes = {len(self)}")
assert len(self) == (old_num_classes + len(new_classes))

(task_0['test'].label_encoder)

(task_0['test'].classes)



sorted(set(datamodule.train_dataset.targets))

len(sorted(set(datamodule.train_dataset.classes)))


len(sorted(set(datamodule.train_dataset.classes)))


datamodule.num_classes

import numpy as np
set(np.arange(len(datamodule.train_dataset.classes))) - set(datamodule.train_dataset.targets)

import numpy as np
set(datamodule.train_dataset.targets) - set(np.arange(len(datamodule.train_dataset.classes)))


len(sorted(set(datamodule.train_dataset.targets)))

model.metrics_train

model.metrics_val

## Export experiment

In [1]:
encoder_path = "/media/data/jacob/GitHub/lightning-hydra-classifiers/notebooks/experiments_August_2021/Extant-to-PNAS-512-transfer_benchmark/task_0/test.json"

from lightning_hydra_classifiers.utils.common_utils import LabelEncoder




In [2]:
encoder = LabelEncoder.load(encoder_path)

encoder


<LabelEncoder>:
    num_classes: 177
    fit on num_samples: 0

In [4]:
len(encoder.class2idx)

len(encoder.index2class)

177

In [5]:
set(encoder.class2idx.keys()) - set(encoder.index2class.values())

{'Fagaceae', 'Phyllanthaceae'}

```
LabelEncoder:  
    ::class2idx  
        - Needs to have a many-to-one mapping from text to int. All int labels are unique, but multiple class names can map to 1 int label.
    ::idx2class  
        - Needs to have a one-to-one mapping from int to text. For ints that map to more than 1 text label, this maps the int only to the correct standardized label for the experiment. As in, it maps only to the label we used to replace another label. e.g. If replacements includes {"NothoFagaceae": "Fagaceae"}, and both  of them maps to int label 16, then encoder.idx2class[16] must return "Fagaceae".
    ::classes  
    ::num_classes  
        - Needs to correspond to actual neural net output size, therefore excludes replaced classes
    ::replacements  
```

* len(idx2class) <= len(class2idx)
* num_classes == len(idx2class) <= len(class2idx)



1. Initialize blank label encoder
    - Provide replacements dict to allow backwards compatible mappings
        - e.g. Nothofagaceae (newer, Extant) -> Fagaceae (older, PNAS)

2. Fit encoder on $|y_s|$ to have 

In [1]:
from lightning_hydra_classifiers.experiments.transfer_experiment import TransferExperiment
output_root_dir = "/media/data_cifs/projects/prj_fossils/users/jacob/experiments/July2021-Nov2021/csv_datasets/experimental_datasets"
experiment = TransferExperiment()
experiment.export_experiment_spec(output_root_dir=output_root_dir)

Exporting experiment to experiment_dir: /media/data_cifs/projects/prj_fossils/users/jacob/experiments/July2021-Nov2021/csv_datasets/experimental_datasets/Extant-to-PNAS-512-transfer_benchmark
train 16605
val 4152
train 2011
val 503
19
92


In [6]:
self = experiment

replace_class_indices = {"Nothofagaceae":"Fagaceae"}

task_0 = self.setup_task_0()
task_1 = self.setup_task_1()

#         import pdb;pdb.set_trace()



#         print(f"__init__: {task_0['train'].label_encoder}")
task_0_label_encoder = task_0['train'].label_encoder
task_0_label_encoder.__init__(replacements = replace_class_indices)

# print(max(task_0_label_encoder.class2idx.values()))
task_0_label_encoder.class2idx

print(len(set(task_0['test'].targets)))





print(len(set(task_0_label_encoder.classes)))
task_0_label_encoder.fit(task_0['test'].targets)
print(len(set(task_0_label_encoder.classes)))

In [7]:
print(max(task_0_label_encoder.class2idx.values()))
task_0_label_encoder.class2idx

18


{'Anacardiaceae': 0,
 'Annonaceae': 1,
 'Apocynaceae': 2,
 'Betulaceae': 3,
 'Celastraceae': 4,
 'Combretaceae': 5,
 'Ericaceae': 6,
 'Fabaceae': 7,
 'Fagaceae': 8,
 'Lauraceae': 9,
 'Malvaceae': 10,
 'Melastomataceae': 11,
 'Myrtaceae': 12,
 'Passifloraceae': 13,
 'Phyllanthaceae': 14,
 'Rosaceae': 15,
 'Rubiaceae': 16,
 'Salicaceae': 17,
 'Sapindaceae': 18,
 'Nothofagaceae': 8}

In [8]:
print(len(set(task_0['train'].targets)))

93


In [None]:


print(len(set(task_0_label_encoder.classes)))
task_0_label_encoder.fit(task_0['train'].targets)
print(len(set(task_0_label_encoder.classes)))

In [10]:


print(max(task_0_label_encoder.class2idx.values()))
task_0_label_encoder.class2idx

In [11]:
task_0_label_encoder.classes

task_0_label_encoder.idx2class

In [3]:
!rm -r "/media/data_cifs/projects/prj_fossils/users/jacob/experiments/July2021-Nov2021/csv_datasets/experimental_datasets/task_1"

In [2]:
class self:
    idx2class = {0:"test",
                 1:"2"}



old_highest_class = max(self.idx2class.keys())

In [3]:
old_highest_class

1

# Refactor LabelEncoder

In [13]:
import os
from pathlib import Path
import pandas as pd
import numpy as np
import numbers
from typing import Union, List, Any, Tuple, Dict, Optional, Sequence
import collections
from sklearn.model_selection import train_test_split
import json
from lightning_hydra_classifiers.utils import template_utils
from lightning_hydra_classifiers.utils.plot_utils import colorbar


log = template_utils.get_logger(__name__)


# __all__ = ["LabelEncoder", "trainval_split", "trainvaltest_split", "plot_split_distributions", "plot_class_distributions",
#            "filter_df_by_threshold", "compute_class_counts"]



class LabelEncoder(object):
    
    """Label encoder for tag labels."""
    def __init__(self,
                 class2idx: Dict[str,int]=None,
                 replacements: Optional[Dict[str,str]]=None):
        self.class2idx = class2idx or {}
        self.replacements = replacements or {}
#         self.idx2class = {v: k for k, v in self.class2idx.items() if k not in self.replacements.keys()}
#         self.classes = [k for k in self.class2idx.keys() if k not in self.replacements.keys()]
        
        assert len(self.classes) == len(self.idx2class) <= len(self.class2idx)
        self.num_samples = 0
        self.verbose=False
        self.replace_class2idx_items()
        

    @property
    def idx2class(self):
        return {v: k for k, v in self.class2idx.items() if k not in self.replacements.keys()}
    
    @property
    def classes(self):
        return [k for k in self.class2idx.keys() if k not in self.replacements.keys()]
    

        
    def replace_class2idx_items(self):
        """
        Update inplace self.class2idx mappings, so that any class labels in self.replacements.keys()
        map to the same int label as their corresponding value in self.replacements.values().
        
        """
        if (len(self.replacements) == 0) \
        or (len([k for k in self.replacements.keys() if k in self.class2idx.keys()]) == 0):
            # No-op if replacements keys are empty or have zero overlap with class2idx keys.
            return
        
        if self.verbose:
            log.info(f'LabelEncoder replacing {len(self.replacements.keys())} class encodings with that other an another class')
            log.info('Replacing: ' + str({k:v for k,v in self.replacements.items() if k in self.class2idx}))
        for old, new in self.replacements.items():
            if old in list(self.class2idx.keys()):
                self.class2idx[old] = self.class2idx[new]
#         self.idx2class = {v: k for k, v in self.class2idx.items()}
#         self.classes = [k for k in self.class2idx.keys() if k not in self.replacements.keys()]                
        
    def __len__(self):
        return len(self.idx2class)
#         return len(self.classes)

    def num_classes(self):
        return len(self)


    def __str__(self):
        msg = f"<LabelEncoder(num_classes={len(self)})>"
        if len(self.replacements) > 0:
            msg += "\n" + f"<num_replaced_classes={len(self.replacements)}>"
        return msg

    def fit(self, y):
        
        counts = collections.Counter(y)
        self.num_samples += sum(counts.values())
        
        classes = sorted(list(counts.keys()))
        new_classes = sorted([label for label in classes if label not in self.classes])
        
        old_num_classes = len(self)
        old_highest_class = max(self.idx2class.keys())
        for i, label in enumerate(new_classes):
            self.class2idx[label] = old_highest_class + i
#         self.idx2class = {v: k for k, v in self.class2idx.items()}
#         self.classes = [k for k in self.class2idx.keys() if k not in self.replacements.keys()]        
        self.replace_class2idx_items()

        new_classes = [c for c in new_classes if c not in self.replacements.keys()]
        if len(new_classes):
            log.debug(f"[FITTING] {len(y)} samples with {len(classes)} classes, adding {len(new_classes)} new class labels. Latest num_classes = {len(self)}")
        assert len(self) == (old_num_classes + len(new_classes))
        assert np.all([label in self.idx2class.values() for label in new_classes])
        return self

    def encode(self, y):
        if not hasattr(y,"__len__"):
            y = [y]
#         print(self.class2idx)
        return np.array([self.class2idx[label] for label in y])

    def decode(self, y):
        if not hasattr(y,"__len__"):
            y = [y]
        return np.array([self.idx2class[label] for label in y])

    def save(self, fp):
        with open(fp, "w") as fp:
            contents = self.getstate() # {"class2idx": self.class2idx}
            json.dump(contents, fp, indent=4, sort_keys=False)

    @classmethod
    def load(cls, fp):
        with open(fp, "r") as fp:
            kwargs = json.load(fp=fp)
        return cls(**kwargs)
    
    def getstate(self):
        return {"class2idx": self.class2idx,
                "replacements": self.replacements}
    
    def __repr__(self):
        disp = f"""<{str(type(self)).strip("'>").split('.')[-1]}>:\n"""
        disp += f"    num_classes: {len(self)}\n"
        disp += f"    fit on num_samples: {self.num_samples}"
        return disp

In [15]:
from lightning_hydra_classifiers.utils.common_utils import LabelEncoder
encoder = LabelEncoder()
experiment = TransferExperiment()
task_0 = experiment.setup_task_0()

old_encoder = task_0['train'].label_encoder

train 16605
val 4152


In [16]:
encoder = LabelEncoder()

print(f"Old:", old_encoder)
print("New, initialized:", encoder)

Old: <LabelEncoder(num_classes=93)>
New, initialized: <LabelEncoder(num_classes=0)>


In [17]:
data = task_0['train']
data_df = data.samples_df

In [29]:
experiment

replace_class_indices = {"Nothofagaceae":"Fagaceae"}

(data.label_encoder.replace)

data_df

In [30]:
data_df = data_df.rename(columns={"family":"newest_family"})
data_df = data_df.assign(family = data_df.newest_family.replace(replace_class_indices))
data_df


Unnamed: 0,path,newest_family,genus,species,collection,catalog_number,family
0,/media/data_cifs/projects/prj_fossils/data/processed_data/leavesdb-v1_0/images/Extant_Leaves/512/50/jpg/Cannabaceae/Cannabaceae_Celtis_biondii_Wolfe_Wolfe_8999b.jpg,Cannabaceae,Celtis,biondii,Wolfe,Wolfe_8999b,Cannabaceae
1,/media/data_cifs/projects/prj_fossils/data/processed_data/leavesdb-v1_0/images/Extant_Leaves/512/50/jpg/Sapindaceae/Sapindaceae_Serjania_acutidentata_Wing_Wing_597-003b.jpg,Sapindaceae,Serjania,acutidentata,Wing,Wing_597-003b,Sapindaceae
2,/media/data_cifs/projects/prj_fossils/data/processed_data/leavesdb-v1_0/images/Extant_Leaves/512/50/jpg/Fabaceae/Fabaceae_Pithecellobium_lasiopus_Hickey_Hickey_4124.jpg,Fabaceae,Pithecellobium,lasiopus,Hickey,Hickey_4124,Fabaceae
3,/media/data_cifs/projects/prj_fossils/data/processed_data/leavesdb-v1_0/images/Extant_Leaves/512/50/jpg/Rutaceae/Rutaceae_Clausena_heptaphylla_Hickey_Hickey_6081.jpg,Rutaceae,Clausena,heptaphylla,Hickey,Hickey_6081,Rutaceae
4,/media/data_cifs/projects/prj_fossils/data/processed_data/leavesdb-v1_0/images/Extant_Leaves/512/50/jpg/Bignoniaceae/Bignoniaceae_Dolichandrone_cauda-felina_Wolfe_Wolfe_2413a.jpg,Bignoniaceae,Dolichandrone,cauda-felina,Wolfe,Wolfe_2413a,Bignoniaceae
...,...,...,...,...,...,...,...
16600,/media/data_cifs/projects/prj_fossils/data/processed_data/leavesdb-v1_0/images/Extant_Leaves/512/50/jpg/Rosaceae/Rosaceae_Prunus_subcordata_Axelrod_Axelrod_387.jpg,Rosaceae,Prunus,subcordata,Axelrod,Axelrod_387,Rosaceae
16601,/media/data_cifs/projects/prj_fossils/data/processed_data/leavesdb-v1_0/images/Extant_Leaves/512/50/jpg/Celastraceae/Celastraceae_Celastrus_paniculatus_Hickey_Hickey_4380.jpg,Celastraceae,Celastrus,paniculatus,Hickey,Hickey_4380,Celastraceae
16602,/media/data_cifs/projects/prj_fossils/data/processed_data/leavesdb-v1_0/images/Extant_Leaves/512/50/jpg/Acanthaceae/Acanthaceae_Staurogyne_anigozanthus_Hickey_Hickey_1238_1.jpg,Acanthaceae,Staurogyne,anigozanthus,Hickey,Hickey_1238_1,Acanthaceae
16603,/media/data_cifs/projects/prj_fossils/data/processed_data/leavesdb-v1_0/images/Extant_Leaves/512/50/jpg/Lauraceae/Lauraceae_Stemmatodaphne_perakensis_Wolfe_Wolfe_15508a.jpg,Lauraceae,Stemmatodaphne,perakensis,Wolfe,Wolfe_15508a,Lauraceae


In [35]:
# data_df = data_df.sort_values("family").astype({"family":pd.CategoricalDtype(),
#                           "newest_family":pd.CategoricalDtype()}) #family.cat #!=data_df.newest_family]

<class 'pandas.core.frame.DataFrame'>
Int64Index: 16605 entries, 0 to 16604
Data columns (total 7 columns):
 #   Column          Non-Null Count  Dtype 
---  ------          --------------  ----- 
 0   path            16605 non-null  string
 1   newest_family   16605 non-null  string
 2   genus           16605 non-null  string
 3   species         16605 non-null  string
 4   collection      16605 non-null  string
 5   catalog_number  16605 non-null  string
 6   family          16605 non-null  object
dtypes: object(1), string(6)
memory usage: 1.0+ MB


In [64]:
data_df = data_df.sort_values("family").astype({"family":pd.CategoricalDtype()})
data_df = data_df.convert_dtypes()
data_df.info()

<class 'pandas.core.frame.DataFrame'>
Int64Index: 16605 entries, 16111 to 8302
Data columns (total 7 columns):
 #   Column          Non-Null Count  Dtype   
---  ------          --------------  -----   
 0   path            16605 non-null  string  
 1   newest_family   16605 non-null  category
 2   genus           16605 non-null  string  
 3   species         16605 non-null  string  
 4   collection      16605 non-null  string  
 5   catalog_number  16605 non-null  string  
 6   family          16605 non-null  category
dtypes: category(2), string(5)
memory usage: 816.3 KB


In [69]:
dir(data_df.family[~data_df.family.duplicated(keep='first')])
# data_df.family.duplicated(keep='first')

['T',
 '_AXIS_LEN',
 '_AXIS_ORDERS',
 '_AXIS_REVERSED',
 '_AXIS_TO_AXIS_NUMBER',
 '_HANDLED_TYPES',
 '__abs__',
 '__add__',
 '__and__',
 '__annotations__',
 '__array__',
 '__array_priority__',
 '__array_ufunc__',
 '__array_wrap__',
 '__bool__',
 '__class__',
 '__contains__',
 '__copy__',
 '__deepcopy__',
 '__delattr__',
 '__delitem__',
 '__dict__',
 '__dir__',
 '__divmod__',
 '__doc__',
 '__eq__',
 '__finalize__',
 '__float__',
 '__floordiv__',
 '__format__',
 '__ge__',
 '__getattr__',
 '__getattribute__',
 '__getitem__',
 '__getstate__',
 '__gt__',
 '__hash__',
 '__iadd__',
 '__iand__',
 '__ifloordiv__',
 '__imod__',
 '__imul__',
 '__init__',
 '__init_subclass__',
 '__int__',
 '__invert__',
 '__ior__',
 '__ipow__',
 '__isub__',
 '__iter__',
 '__itruediv__',
 '__ixor__',
 '__le__',
 '__len__',
 '__long__',
 '__lt__',
 '__matmul__',
 '__mod__',
 '__module__',
 '__mul__',
 '__ne__',
 '__neg__',
 '__new__',
 '__nonzero__',
 '__or__',
 '__pos__',
 '__pow__',
 '__radd__',
 '__rand__',
 '__r

In [72]:
y_col = "family"
category_df = data_df[y_col][~data_df[y_col].duplicated(keep='first')]   #.to_list()
class2index = dict(zip(category_df.to_list(), category_df.cat.codes.to_list()))
# data_df.family.duplicated(keep='first')

In [73]:
class2index

{'Acanthaceae': 0,
 'Achariaceae': 1,
 'Actinidiaceae': 2,
 'Altingiaceae': 3,
 'Amaranthaceae': 4,
 'Anacardiaceae': 5,
 'Annonaceae': 6,
 'Apiaceae': 7,
 'Apocynaceae': 8,
 'Aquifoliaceae': 9,
 'Araliaceae': 10,
 'Asteraceae': 11,
 'Berberidaceae': 12,
 'Betulaceae': 13,
 'Bignoniaceae': 14,
 'Burseraceae': 15,
 'Cannabaceae': 16,
 'Capparaceae': 17,
 'Caprifoliaceae': 18,
 'Celastraceae': 19,
 'Chloranthaceae': 20,
 'Chrysobalanaceae': 21,
 'Clusiaceae': 22,
 'Combretaceae': 23,
 'Connaraceae': 24,
 'Cornaceae': 25,
 'Crassulaceae': 26,
 'Cunoniaceae': 27,
 'Dilleniaceae': 28,
 'Dipterocarpaceae': 29,
 'Ebenaceae': 30,
 'Elaeocarpaceae': 31,
 'Ericaceae': 32,
 'Euphorbiaceae': 33,
 'Fabaceae': 34,
 'Fagaceae': 35,
 'Grossulariaceae': 36,
 'Hamamelidaceae': 37,
 'Hydrangeaceae': 38,
 'Icacinaceae': 39,
 'Juglandaceae': 40,
 'Lamiaceae': 41,
 'Lauraceae': 42,
 'Lecythidaceae': 43,
 'Loranthaceae': 44,
 'Lythraceae': 45,
 'Magnoliaceae': 46,
 'Malpighiaceae': 47,
 'Malvaceae': 48,
 'Ma

In [37]:
data_df.info()

<class 'pandas.core.frame.DataFrame'>
Int64Index: 16605 entries, 0 to 16604
Data columns (total 7 columns):
 #   Column          Non-Null Count  Dtype   
---  ------          --------------  -----   
 0   path            16605 non-null  string  
 1   newest_family   16605 non-null  string  
 2   genus           16605 non-null  string  
 3   species         16605 non-null  string  
 4   collection      16605 non-null  string  
 5   catalog_number  16605 non-null  string  
 6   family          16605 non-null  category
dtypes: category(1), string(6)
memory usage: 927.1 KB


In [43]:
dir(data_df.family.cat)

['__annotations__',
 '__class__',
 '__delattr__',
 '__dict__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__frozen',
 '__ge__',
 '__getattribute__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__le__',
 '__lt__',
 '__module__',
 '__ne__',
 '__new__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__sizeof__',
 '__str__',
 '__subclasshook__',
 '__weakref__',
 '_accessors',
 '_add_delegate_accessors',
 '_constructor',
 '_delegate_method',
 '_delegate_property_get',
 '_delegate_property_set',
 '_dir_additions',
 '_dir_deletions',
 '_freeze',
 '_hidden_attrs',
 '_index',
 '_name',
 '_parent',
 '_reset_cache',
 '_validate',
 'add_categories',
 'as_ordered',
 'as_unordered',
 'categories',
 'codes',
 'ordered',
 'remove_categories',
 'remove_unused_categories',
 'rename_categories',
 'reorder_categories',
 'set_categories']

In [50]:
len(set(data_df.family.cat.codes))

92

In [44]:
data_df.family.cat.ordered

False

In [None]:
data_df.family.cat.codes

In [33]:
data_df[data_df.family!=data_df.newest_family]

Unnamed: 0,path,newest_family,genus,species,collection,catalog_number,family
533,/media/data_cifs/projects/prj_fossils/data/processed_data/leavesdb-v1_0/images/Extant_Leaves/512/50/jpg/Nothofagaceae/Nothofagaceae_Nothofagus_macrophylla_Hickey_Hickey_6169.jpg,Nothofagaceae,Nothofagus,macrophylla,Hickey,Hickey_6169,Fagaceae
668,/media/data_cifs/projects/prj_fossils/data/processed_data/leavesdb-v1_0/images/Extant_Leaves/512/50/jpg/Nothofagaceae/Nothofagaceae_Nothofagus_grandis_Hickey_Hickey_1766.jpg,Nothofagaceae,Nothofagus,grandis,Hickey,Hickey_1766,Fagaceae
1237,/media/data_cifs/projects/prj_fossils/data/processed_data/leavesdb-v1_0/images/Extant_Leaves/512/50/jpg/Nothofagaceae/Nothofagaceae_Trisyngyne_discoidea_Hickey_Hickey_718.jpg,Nothofagaceae,Trisyngyne,discoidea,Hickey,Hickey_718,Fagaceae
1277,/media/data_cifs/projects/prj_fossils/data/processed_data/leavesdb-v1_0/images/Extant_Leaves/512/50/jpg/Nothofagaceae/Nothofagaceae_Nothofagus_discoidea_Wolfe_Wolfe_8533.jpg,Nothofagaceae,Nothofagus,discoidea,Wolfe,Wolfe_8533,Fagaceae
1685,/media/data_cifs/projects/prj_fossils/data/processed_data/leavesdb-v1_0/images/Extant_Leaves/512/50/jpg/Nothofagaceae/Nothofagaceae_Nothofagus_aequilateralis_Wolfe_Wolfe_8530.jpg,Nothofagaceae,Nothofagus,aequilateralis,Wolfe,Wolfe_8530,Fagaceae
3066,/media/data_cifs/projects/prj_fossils/data/processed_data/leavesdb-v1_0/images/Extant_Leaves/512/50/jpg/Nothofagaceae/Nothofagaceae_Nothofagus_nitida_Hickey_Hickey_1781.jpg,Nothofagaceae,Nothofagus,nitida,Hickey,Hickey_1781,Fagaceae
3294,/media/data_cifs/projects/prj_fossils/data/processed_data/leavesdb-v1_0/images/Extant_Leaves/512/50/jpg/Nothofagaceae/Nothofagaceae_Nothofagus_discoidea_Hickey_Hickey_6432.jpg,Nothofagaceae,Nothofagus,discoidea,Hickey,Hickey_6432,Fagaceae
4482,/media/data_cifs/projects/prj_fossils/data/processed_data/leavesdb-v1_0/images/Extant_Leaves/512/50/jpg/Nothofagaceae/Nothofagaceae_Nothofagus_perryi_Wolfe_Wolfe_8534.jpg,Nothofagaceae,Nothofagus,perryi,Wolfe,Wolfe_8534,Fagaceae
5272,/media/data_cifs/projects/prj_fossils/data/processed_data/leavesdb-v1_0/images/Extant_Leaves/512/50/jpg/Nothofagaceae/Nothofagaceae_Nothofagus_gunni_Hickey_Hickey_1773.jpg,Nothofagaceae,Nothofagus,gunni,Hickey,Hickey_1773,Fagaceae
5283,/media/data_cifs/projects/prj_fossils/data/processed_data/leavesdb-v1_0/images/Extant_Leaves/512/50/jpg/Nothofagaceae/Nothofagaceae_Nothofagus_fusca_Wolfe_Wolfe_3218.jpg,Nothofagaceae,Nothofagus,fusca,Wolfe,Wolfe_3218,Fagaceae


# BYOL

## Load and preprocess pre-formatted csv datasets and create train val test splits

In [1]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '5'

from torchvision.models import mobilenet_v2, resnet50
from torchvision.datasets import STL10
from torchvision.transforms import ToTensor
from torchvision import transforms
from torchvision.utils import make_grid, save_image
import matplotlib.pyplot as plt
import os

import torch.nn as nn
import torch
device = 'cuda' if torch.cuda.is_available() else 'cpu'

from lightning_hydra_classifiers.train_BYOL import *
torch.backends.cudnn.benchmark = True
from munch import Munch

In [2]:
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])

if 'TOY_DATA_DIR' not in os.environ: 
    print(f"Setting env variable $TOY_DATA_DIR={os.environ['TOY_DATA_DIR']}")
    os.environ['TOY_DATA_DIR'] = "/media/data_cifs/projects/prj_fossils/data/toy_data"
default_root_dir = os.environ['TOY_DATA_DIR']
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"



config = Munch({"dataset_name":"Extant-PNAS",
                "model":{
                    "backbone":"resnet50"}
               })
# config = Munch({"dataset_name":"STL10"})
print(f"config: {config}")


transform = transforms.Compose([ToTensor(),
                               normalize])



if config.dataset_name == "STL10":
    from torchbearer.cv_utils import DatasetValidationSplitter

    train_data = STL10(os.environ['TOY_DATA_DIR'], split='train', transform=transform, download=True)
    test_data = STL10(os.environ['TOY_DATA_DIR'], split='test', transform=transform, download=True)
    
    splitter = DatasetValidationSplitter(len(train_data), 0.1)
    train_set = splitter.get_train_dataset(train_data)
    val_set = splitter.get_val_dataset(train_data)
    
else:
    exp = TransferExperiment()
    task_0, task_1 = exp.get_multitask_datasets(train_transform=transform,
                                                val_transform=transform)
    train_data, val_data, test_data = task_0["train"], task_0["val"], task_0["test"]
    train_set, val_set = train_data, val_data

classes = train_data.classes
num_classes = len(classes)
print('\n List of all classes: ')
print(classes)
print(f"len(classes)={len(classes)}")


BATCH_SIZE = 16
num_workers = 2
pin_memory = False

train_gen = torch.utils.data.DataLoader(train_set, pin_memory=pin_memory, batch_size=BATCH_SIZE, shuffle=True, num_workers=num_workers)
val_gen = torch.utils.data.DataLoader(val_set, pin_memory=pin_memory, batch_size=BATCH_SIZE, shuffle=False, num_workers=num_workers)
test_gen = torch.utils.data.DataLoader(test_data, pin_memory=pin_memory, batch_size=BATCH_SIZE, shuffle=False, num_workers=num_workers)

config: Munch({'dataset_name': 'Extant-PNAS', 'model': {'backbone': 'resnet50'}})

 List of all classes: 
['Fagaceae', 'Ericaceae', 'Fabaceae', 'Anacardiaceae', 'Rosaceae', 'Betulaceae', 'Salicaceae', 'Sapindaceae', 'Lauraceae', 'Rubiaceae', 'Celastraceae', 'Malvaceae', 'Myrtaceae', 'Apocynaceae', 'Melastomataceae', 'Passifloraceae', 'Combretaceae', 'Annonaceae', 'Phyllanthaceae', 'Clusiaceae', 'Sapotaceae', 'Lythraceae', 'Burseraceae', 'Bignoniaceae', 'Meliaceae', 'Malpighiaceae', 'Cunoniaceae', 'Onagraceae', 'Marantaceae', 'Santalaceae', 'Berberidaceae', 'Ochnaceae', 'Ebenaceae', 'Cornaceae', 'Sabiaceae', 'Schisandraceae', 'Araliaceae', 'Staphyleaceae', 'Hamamelidaceae', 'Dipterocarpaceae', 'Dichapetalaceae', 'Dilleniaceae', 'Proteaceae', 'Connaraceae', 'Caprifoliaceae', 'Piperaceae', 'Bonnetiaceae', 'Juglandaceae', 'Geraniaceae', 'Aquifoliaceae', 'Moraceae', 'Rutaceae', 'Cercidiphyllaceae', 'Orchidaceae', 'Orobanchaceae', 'Apiaceae', 'Violaceae', 'Altingiaceae', 'Eucommiaceae', 'Gne

In [4]:
# cm = PyCM().on_val().to_html_file('cm.{epoch}')

# We copy the final layer form MobileNetV2 and replace the linear layer with one to 10 channels

if config.model.backbone == "mobilenet_v2":
    model = mobilenet_v2(pretrained=True, progress=False)
    model.classifier = nn.Sequential(
                nn.Dropout(0.2),
                nn.Linear(model.last_channel, num_classes),
            )

elif config.model.backbone == "resnet50":
    model = resnet50(pretrained=True, progress=False)
    model.fc = nn.Sequential(
                nn.Dropout(0.2),
                nn.Linear(model.fc.in_features, num_classes+1),
            )

In [11]:
# import torchsummary
# import torchinfo
import torch.optim as optim
# import torchbearer
# from torchbearer import Trial
# from torchbearer.callbacks import PyCM

# for k, m in model.named_modules():
#     if (k.startswith("features")) or (k.startswith("layer")):
#         print(f"Freezing: {k}")
#         m.requires_grad = False
#     else:
#         print(f"Unfreezing: {k}")
#         m.requires_grad = True
        
#     print(f"{k} : {m.requires_grad}")

freeze_at = "fc"

freeze_current=False
for k, m in model.named_modules():
    if (k == freeze_at) or freeze_current:
        freeze_current=True
        print(f"Unfreezing: {k}")
        m.requires_grad = True
    else:
        print(f"Freezing: {k}")
        m.requires_grad = False
        
    print(f"{k} : {m.requires_grad}")

In [None]:
# cm = PyCM().on_val().to_pyplot(normalize=True, title='Confusion Matrix: {epoch}')
# cm_csv = PyCM().on_val().to_csv_file("cm_{epoch}")
# model = SimpleModel()
optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=0.001)
loss = nn.CrossEntropyLoss()

# trial = Trial(model, optimizer, loss, metrics=['acc', 'loss'], callbacks=[cm,cm_csv]).to(device)
# trial.with_generators(train_generator=train_gen, val_generator=val_gen)
# history = trial.run(epochs=2, verbose=2)

In [None]:
model = SelfSupervisedLearner(
    resnet,
    image_size = IMAGE_SIZE,
    hidden_layer = 'avgpool',
    projection_size = 256,
    projection_hidden_size = 4096,
    moving_average_decay = 0.99
)

trainer = pl.Trainer(
    gpus = NUM_GPUS,
    max_epochs = EPOCHS,
    accumulate_grad_batches = 1,
    sync_batchnorm = True
)

trainer.fit(model, train_loader)

In [3]:
import torchdata

# class UnsupervisedDatasetWrapper(torchdata.datasets.Files):
class UnsupervisedDatasetWrapper(torchvision.datasets.ImageFolder):
    
    def __init__(self, dataset):
        
        self.dataset = dataset
        
    def __getitem__(self, index):
        return self.dataset[index][0]
    
    def __len__(self):
        return len(self.dataset)
    
    def __repr__(self):
        out = "<UnsupervisedDatasetWrapper>\n"
        out += self.dataset.__repr__()
        return out

In [8]:
task_0['test'][0]

for subset in ["train","val","test"]:
    task_0[subset] = UnsupervisedDatasetWrapper(task_0[subset])
    task_1[subset] = UnsupervisedDatasetWrapper(task_1[subset])

task_0['test'][0]

type(task_0['test'].dataset)

task_0['test']#.dataset

### Create transforms

In [14]:
from torchvision import transforms
from typing import *


totensor: Callable = torchvision.transforms.ToTensor()

def toPIL(img: torch.Tensor, mode="RGB") -> Callable:
    return torchvision.transforms.ToPILImage(mode)


def normalize_transform(mean = [0.485, 0.456, 0.406],
                        std = [0.229, 0.224, 0.225]) -> Callable:
    return transforms.Normalize(mean=mean,
                                std=std)

def default_train_transforms(image_size: int=224,
                             normalize: bool=True, 
                             augment:bool=True,
                             grayscale: bool=True,
                             channels: Optional[int]=3,
                             mean = [0.485, 0.456, 0.406],
                             std = [0.229, 0.224, 0.225]):
    """Subclasses can override this or user can provide custom transforms at runtime"""
    transform_list = []
#         transform_jit_list = []
    resize_PIL = not augment
    if augment:
        transform_list.extend([transforms.RandomResizedCrop(size=image_size,
                                                            scale=(0.25, 1.2),
                                                            ratio=(0.7, 1.3),
                                                            interpolation=2),
                               totensor
                             ])
    return default_eval_transforms(image_size=image_size,
                                        normalize=normalize,
                                        resize_PIL=resize_PIL,
                                        grayscale=grayscale,
                                        channels=channels,
                                        transform_list=transform_list,
                                        mean=mean,
                                        std=std)

def default_eval_transforms(image_size: int=224,
                            image_buffer_size: int=32,
                            normalize: bool=True,
                            resize_PIL: bool=True,
                            grayscale: bool=True,
                            channels: Optional[int]=3,
                            transform_list: Optional[List[Callable]]=None,
                            mean = [0.485, 0.456, 0.406],
                            std = [0.229, 0.224, 0.225]):
    """Subclasses can override this or user can provide custom transforms at runtime"""
    transform_list = transform_list or []
    transform_jit_list = []

    if resize_PIL:
        # if True, assumes input images are PIL.Images (But need to check if this even matters.)
        # if False, expects input images to already be torch.Tensors
        transform_list.extend([transforms.Resize(image_size+image_buffer_size),
                               transforms.CenterCrop(image_size),
                               totensor])
    if normalize:
        transform_jit_list.append(normalize_transform(mean, std))

    if grayscale:
        transform_jit_list.append(transforms.Grayscale(num_output_channels=channels))

    return transforms.Compose([*transform_list, *transform_jit_list])


def get_default_transforms(image_size: int=224,
                           normalize: bool=True,
                           augment:bool=True,
                           grayscale: bool=True,
                           channels: Optional[int]=3,
                           mean = [0.485, 0.456, 0.406],
                           std = [0.229, 0.224, 0.225]):

    
    train_transform = default_train_transforms(image_size=image_size,
                                               normalize=normalize,
                                               augment=augment,
                                               grayscale=grayscale,
                                               channels=channels,
                                               mean=mean,
                                               std=std)
    eval_transform = default_eval_transforms(image_size=image_size,
                                             image_buffer_size=32,
                                             normalize=normalize,
                                             resize_PIL=not augment,
                                             grayscale=grayscale,
                                             channels=channels,
                                             transform_list=None,
                                             mean=mean,
                                             std=std)
    
    
    
    return train_transform, eval_transform

In [15]:
train_transform, val_transform = get_default_transforms(image_size=224,
                                                         normalize=True,
                                                         augment=True,
                                                         grayscale=True,
                                                         channels=3,
                                                         mean = [0.485, 0.456, 0.406],
                                                         std = [0.229, 0.224, 0.225])