In [2]:
import torch
from torchvision import transforms
from torch.utils.data import DataLoader
from torch import nn
import numpy as np
import os
import yaml
from pathlib import Path
from PIL import Image

from SeaAnimalsDataset import SeaAnimalsDataset

In [3]:
# Relative directory path kept as a plain string; no other cell shown here reads it.
# NOTE(review): presumably the dataset root — confirm, or remove if unused.
path = "./archive/"

In [4]:
# Location of the YAML experiment config that the next cell opens and parses.
argv = "./baseline_config.yaml"

In [5]:
# Parse the YAML experiment config into a plain dict.
# The context manager closes the file handle as soon as parsing finishes
# (the bare open() call used previously left it open until garbage collection).
with open(argv, "r") as config_file:
    config = yaml.safe_load(config_file)

In [6]:
# Echo the parsed configuration so the loaded settings show up in the cell output.
config

{'model': {'name': 'Baseline'}, 'augmentation': {'name': '00'}}

In [7]:
class BaseLineModel(nn.Module):
    """Placeholder for the baseline model; no layers or forward pass are defined yet."""

    def __init__(self):
        # nn.Module creates its parameter/submodule registries in its own
        # __init__; without this call, parameters(), .to() and assigning layers
        # all fail with AttributeError.
        super().__init__()
class BestCNN(nn.Module):
    """Placeholder for the CNN model; no layers or forward pass are defined yet."""

    def __init__(self):
        # nn.Module creates its parameter/submodule registries in its own
        # __init__; without this call, parameters(), .to() and assigning layers
        # all fail with AttributeError.
        super().__init__()
class Transformer(nn.Module):
    """Placeholder for the transformer model; no layers or forward pass are defined yet."""

    def __init__(self):
        # nn.Module creates its parameter/submodule registries in its own
        # __init__; without this call, parameters(), .to() and assigning layers
        # all fail with AttributeError.
        super().__init__()

def create_model(config: dict):
    """Instantiate the model selected by ``config["model"]["name"]``.

    "Baseline" yields a ``BaseLineModel`` and "BestCNN" yields a ``BestCNN``.
    Every other value (including misspellings) falls through to ``Transformer``.
    """
    requested = config["model"]["name"]

    if requested == "Baseline":
        return BaseLineModel()
    if requested == "BestCNN":
        return BestCNN()

    # No explicit match: the transformer is the catch-all choice.
    return Transformer()

def get_transform():
    """Build the image preprocessing pipeline shared by all dataset splits.

    Steps, in order: resize to 128x128, convert to a tensor, then normalize
    each of the three channels with mean 0 and std 1.
    """
    steps = [
        # Bring every image to the same spatial size.
        transforms.Resize((128, 128)),
        # Converts the image to a float tensor scaled into [0, 1] (divides by 255);
        # per the torchvision docs the result is channel-first (C, H, W).
        transforms.ToTensor(),
        # channel = (channel - mean) / std; with mean 0 and std 1 the values are
        # left unchanged, so this is a placeholder for real dataset statistics.
        transforms.Normalize(mean=[0, 0, 0], std=[1, 1, 1]),
    ]
    return transforms.Compose(steps)

def create_dataloaders(config: dict, batch_size: int = 1):
    """Create the train, validation and test ``DataLoader`` objects.

    Args:
        config: Experiment configuration. ``config["augmentation"]["name"]``
            selects which augmentation ids are handed to the training dataset:
            ``'all'`` selects ``'00'`` through ``'06'``, a missing or empty
            value selects only ``'00'``, and any other value is used as the
            single id.
        batch_size: Batch size applied to all three loaders. Defaults to 1,
            the value that was previously hard-coded.

    Returns:
        A ``(train_loader, val_loader, test_loader)`` tuple.
    """
    augmentations = _resolve_augmentations(config)
    transformation = get_transform()  # same preprocessing for every split

    sea_animals_train = SeaAnimalsDataset(
        img_path="train_augment",
        transform=transformation,
        train=True,
        augmentations=augmentations,
    )
    sea_animals_val = SeaAnimalsDataset(
        img_path="val",
        transform=transformation,
        train=False,
    )
    sea_animals_test = SeaAnimalsDataset(
        img_path="test",
        transform=transformation,
        train=False,
    )

    # NOTE(review): the val/test loaders are shuffled as well. Evaluation does
    # not normally need that; kept as-is so existing behaviour is unchanged.
    loader_options = {"batch_size": batch_size, "shuffle": True, "drop_last": False}
    train_loader = DataLoader(sea_animals_train, **loader_options)
    val_loader = DataLoader(sea_animals_val, **loader_options)
    test_loader = DataLoader(sea_animals_test, **loader_options)

    return train_loader, val_loader, test_loader


def _resolve_augmentations(config: dict) -> list:
    """Map the ``augmentation.name`` config entry to a list of augmentation ids."""
    # Tolerate a config without an "augmentation" section (or with an empty
    # one) instead of raising; it is treated like an empty name.
    name = (config.get("augmentation") or {}).get("name")
    if name == 'all':
        return ['00', '01', '02', '03', '04', '05', '06']
    if not name:
        return ['00']
    return [name]
    

def make(config: dict):
    """Build everything an experiment needs from one config dict.

    Returns a 4-tuple ``(model, train_loader, val_loader, test_loader)``: the
    model comes from ``create_model`` and the three loaders from
    ``create_dataloaders``, both driven by the same ``config``.
    """
    net = create_model(config)
    train_dl, val_dl, test_dl = create_dataloaders(config)
    return net, train_dl, val_dl, test_dl
        


# Build the model and the three loaders from the loaded config.
# Note: train_loader / val_loader / test_loader are re-assigned by a later cell.
model, train_loader, val_loader, test_loader = make(config)

In [8]:
# Shared preprocessing pipeline, then the training split limited to augmentation id "00".
transformation = get_transform()
sea_animals_train = SeaAnimalsDataset(
    img_path="train_augment",
    transform=transformation,
    train=True,
    augmentations=["00"],
)

In [347]:
# Build the three dataset splits by hand; relies on `transformation` from the cell above.
augmentations = ['00']
train_imgs_path = 'train_augment'

sea_animals_train = SeaAnimalsDataset(
    img_path=train_imgs_path,
    transform=transformation,
    train=True,
    augmentations=augmentations,
)
sea_animals_val = SeaAnimalsDataset(img_path="val", transform=transformation, train=False)
sea_animals_test = SeaAnimalsDataset(img_path="test", transform=transformation, train=False)

In [358]:
# Wrap each split in a DataLoader; all three share the same options.
loader_options = dict(batch_size=1, shuffle=True, drop_last=False)
train_loader = DataLoader(sea_animals_train, **loader_options)
val_loader = DataLoader(sea_animals_val, **loader_options)
test_loader = DataLoader(sea_animals_test, **loader_options)

In [359]:
# Sanity check: show the type of each loader object.
tuple(type(loader) for loader in (train_loader, val_loader, test_loader))

(torch.utils.data.dataloader.DataLoader,
 torch.utils.data.dataloader.DataLoader,
 torch.utils.data.dataloader.DataLoader)

In [363]:
# Report the size of each split and its share of the total.
# len(loader) counts batches; with batch_size=1 above that equals the sample count.
split_sizes = [len(train_loader), len(val_loader), len(test_loader)]
total = sum(split_sizes)
print(total)
print(*split_sizes)
print(*(size / total for size in split_sizes))

11742
8242 2345 1155
0.7019247146993698 0.19971044115142225 0.09836484414920797


In [None]:
# TODO: planned experiments and results to report
# model_1
# plot of F1 on train/val data
# plot of F1 on train/val data with augmentations
# average F1 on the hold-out data set
# model_2
# transformer