In [1]:
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

Mounted at /content/drive


In [2]:
# move into project directory
repo_name = "flower-fgvc"
%cd /content/drive/MyDrive/Personal-Projects/$repo_name
!ls

/content/drive/MyDrive/Personal-Projects/flower-fgvc
common	     data	experiments  index.py  README.md
config.yaml  datautils	Index.ipynb  models    run.yaml


In [3]:
# set up environment
# comment out if not required
'''
!pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
!pip install matplotlib numpy pandas pyyaml opencv-python
'''

#!pip install transformers


'\n!pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118\n!pip install matplotlib numpy pandas pyyaml opencv-python\n'

In [4]:
# this cell is for downloading data.
# as of yet data is not hosted and is available in the private data folder

#!tar xf data/102flowers.tgz -C data/

In [5]:
#set up some imports

import numpy as np
import torch
import random
from torchvision import transforms

# custom imports

from common.utils import init_config, get_exp_params
from datautils.dataset import FlowerDataset
from datautils.datareader import get_file_paths

In [6]:
seed = 123
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cudnn.deterministic = True

In [7]:
config_params = init_config()
print('nb', config_params)

nb {'data_dir': '/content/drive/MyDrive/Personal-Projects/flower-fgvc/data', 'device': 'cpu', 'output_dir': '/content/drive/MyDrive/Personal-Projects/flower-fgvc/output', 'root_dir': '/content/drive/MyDrive/Personal-Projects/flower-fgvc', 'use_gpu': False}


In [8]:
# read experiment params

exp_params = get_exp_params()
print('Experiment parameters\n')
print(exp_params)

Experiment parameters

{'transform': {'resize_dim': 256, 'crop_dim': 224}, 'train': {'batch_size': 32, 'loss': 'cross-entropy', 'epoch_interval': 1, 'num_epochs': 1}, 'model': {'name': 'alexnet', 'optimizer': 'Adam', 'lr': 0.001, 'weight_decay': 1e-07, 'amsgrad': True, 'momentum': 0.8, 'build_on_pretrained': False, 'pretrained_filename': '/models/checkpoints/last_model.pt'}, 'dataset': {'size': 'subset'}}


In [9]:
from torch.utils.data import Dataset
import os
from scipy.io import loadmat
import torch
from torchvision.io import read_image


class FlowerDataset(Dataset):

    def __init__(self, data_dir, data_filepaths, transforms):
        img_dir = os.path.join(data_dir, 'jpg')
        #seg_mask_dir = os.path.join(data_dir, 'segmim')
        labels_path = os.path.join(data_dir, 'imagelabels.mat')
        labels_mat = loadmat(labels_path)
        #id_path = os.path.join(data_dir, 'setid.mat')
        #ids = loadmat(id_path)
        #print(ids)
        self.data_filepaths = data_filepaths
        self.img_dir = img_dir
        self.labels_tensor = torch.from_numpy(labels_mat['labels'][0]).int() - 1
        self.num_classes = len(self.labels_tensor.unique())
        self.data_transform = transforms
        #print('unique ids', ids['tstid'].min(), ids['tstid'].max())

    def __len__(self):
        return len(self.data_filepaths)

    def __getitem__(self, idx):
        fn = self.data_filepaths[idx]
        si = fn.find("_")
        img_idx = int(fn[si+1:si+6])
        img_tensor = read_image(os.path.join(self.img_dir, fn)).float()
        img_tensor = self.data_transform(img_tensor)
        label = self.labels_tensor[img_idx]
        onehot_label = torch.zeros(self.num_classes, dtype=torch.float)
        onehot_label[label] = 1
        soft_label = torch.zeros(self.num_classes, dtype=torch.float)

        return {
            'img': img_tensor,
            'label': label,
            'olabel': onehot_label,
            'slabel': soft_label
        }



In [10]:
composed_transforms =  transforms.Compose([
            #transforms.ToTensor(),
            #transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                        #std=[0.229, 0.224, 0.225]),
            transforms.Resize(exp_params['transform']['resize_dim']),
            transforms.CenterCrop(exp_params['transform']['crop_dim'])
        ])

train_fns, val_fns, test_fns, _ = get_file_paths(config_params['data_dir'])
ftr_dataset = FlowerDataset(config_params['data_dir'], train_fns, composed_transforms)
val_dataset = FlowerDataset(config_params['data_dir'], val_fns, composed_transforms)
test_dataset = FlowerDataset(config_params['data_dir'], test_fns, composed_transforms)
sm_trlen = int(0.1 * len(ftr_dataset))
sm_telen = int(0.01 * len(test_dataset))
sm_vlen = int(0.1 * len(val_dataset))

sm_ftr_dataset = torch.utils.data.Subset(ftr_dataset, list(range(sm_trlen)))
sm_val_dataset = torch.utils.data.Subset(val_dataset, list(range(sm_vlen)))
sm_test_dataset = torch.utils.data.Subset(test_dataset, list(range(sm_telen)))

print('Full train dataset length', len(ftr_dataset))
print('Subset train dataset length', sm_trlen)
print('\nFull validation dataset length', len(val_dataset))
print('Subset validation dataset length', sm_vlen)
print('\nFull test dataset length', len(test_dataset))
print('Subset test dataset length', sm_telen)

Full train dataset length 1020
Subset train dataset length 102

Full validation dataset length 1020
Subset validation dataset length 102

Full test dataset length 6149
Subset test dataset length 61


In [11]:
'''
from torch.utils.data import DataLoader
import requests
from transformers import BlipProcessor, BlipForConditionalGeneration, pipeline
import matplotlib.pyplot as plt

processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
'''

'''
generator = pipeline("text-generation", model="gpt2")
prompt = f"Describe the characteristics of the flower species named sunflower"
description = generator(prompt, max_length = 100, num_return_sequences=1, truncation = True)
print(description[0]['generated_text'])
print('\n\n')
'''

'''
train_loader = DataLoader(sm_ftr_dataset, batch_size = 1, shuffle = False)

# conditional image captioning
text = ""
plt.axis("off")
for bi, batch in enumerate(train_loader):
    img = batch['img'].float().to(config_params['device']) / 255.0
    print(batch['label'])
    prompt = f"Photograph of the flower pink primrose"
    np_img = img.transpose(1, 3).transpose(1, 2).numpy()
    inputs = processor(np_img[0], return_tensors="pt")
    plt.imshow(np_img[0])
    plt.show()
    out = model.generate(**inputs)
    print(processor.decode(out[0], skip_special_tokens = True))
    if bi == 0:
        break
'''

'\ntrain_loader = DataLoader(sm_ftr_dataset, batch_size = 1, shuffle = False)\n\n# conditional image captioning\ntext = ""\nplt.axis("off")\nfor bi, batch in enumerate(train_loader):\n    img = batch[\'img\'].float().to(config_params[\'device\']) / 255.0\n    print(batch[\'label\'])\n    prompt = f"Photograph of the flower pink primrose"\n    np_img = img.transpose(1, 3).transpose(1, 2).numpy()\n    inputs = processor(np_img[0], return_tensors="pt")\n    plt.imshow(np_img[0])\n    plt.show()\n    out = model.generate(**inputs)\n    print(processor.decode(out[0], skip_special_tokens = True))\n    if bi == 0:\n        break\n'

In [12]:

from models.custom_models import get_model
import torch
from common.utils import get_exp_params, get_config, save_experiment_output, save_model_helpers, save_model_chkpt
from torch.utils.data import DataLoader, Subset
from tqdm import tqdm
import os
from common.loss_utils import ECCLoss

class Classification:

    def __init__(self, train_dataset, val_dataset, test_dataset):
        cfg = get_config()
        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        self.test_dataset = test_dataset
        self.exp_params = get_exp_params()
        self.model_params = self.exp_params['model']
        self.device = cfg['device']
        self.num_classes = 102
        #self.data_transform = transforms

    def __loss_fn(self, loss_name = 'cross-entropy'):
        if loss_name == 'cross-entropy':
            return torch.nn.CrossEntropyLoss()
        elif loss_name == 'mse':
            return torch.nn.MSELoss()
        elif loss_name == 'l1':
            return torch.nn.L1Loss()
        elif loss_name == 'ecc':
            loss_fn = ECCLoss(self.num_classes, self.dim)
            return loss_fn
        else:
            raise SystemExit("Error: no valid loss function name passed! Check run.yaml")

    def __save_model_checkpoint(self, model_state, optimizer_state, chkpt_info):
        save_experiment_output(model_state, chkpt_info, False)
        save_model_helpers(optimizer_state, True)
        #os.remove(os.path.join(self.root_dir, "models/checkpoints/current_model.pt"))

    def __conduct_training(self, model, optimizer, train_loader, val_loader, tr_len, val_len):
        num_epochs = self.exp_params['train']['num_epochs']
        epoch_interval = self.exp_params['train']['epoch_interval']
        loss_fn = self.__loss_fn()
        trlosshistory, vallosshistory, valacchistory = [], [], []

        for epoch in range(num_epochs):

            model.train()
            tr_loss, val_loss, val_acc = 0.0, 0.0, 0.0

            for _, batch in enumerate(tqdm(train_loader, desc = '\t\tRunning through training set', position = 0, leave = True, disable = True)):
                optimizer.zero_grad()
                imgs = batch['img'].float().to(self.device)
                olabels = batch['olabel']
                op,feats = model(imgs)
                print('op sz', op.size(), olabels.size(), feats.size())
                loss = loss_fn(op, olabels)
                loss.backward()
                optimizer.step()
                tr_loss += (loss.item() * imgs.size(0))


            tr_loss /= tr_len
            trlosshistory.append(tr_loss)

            model.eval()

            for _, batch in enumerate(tqdm(val_loader, desc = '\t\tRunning through validation set', position = 0, leave = True, disable = True)):
                imgs = batch['img'].float().to(self.device)
                olabels = batch['olabel']
                op = model(imgs)
                loss = loss_fn(op, olabels)
                val_loss += (loss.item() * imgs.size(0))
                correct_label = batch['label']
                pred_label = torch.argmax(op, 1)
                #print('label size', correct_label.size(), pred_label.size())
                val_acc += (correct_label == pred_label).sum()

            val_loss /= val_len
            val_acc /=  val_len
            vallosshistory.append(val_loss)
            valacchistory.append(val_acc.item())

            if epoch % epoch_interval == 0:
                print(f'\tEpoch {epoch+1} Training Loss: {tr_loss}')
                print(f"\tEpoch {epoch+1} Validation Loss: {val_loss}\n")

        model_info = {
            'trlosshistory': trlosshistory,
            'vallosshistory': vallosshistory,
            'valacchistory': valacchistory,
            'last_epoch': -1
        }
        self.__save_model_checkpoint(model, optimizer, model_info)


    def run_fgvc_pipeline(self):
        print('entered pipeline')
        model_name = self.model_params['name']
        model = get_model(102, model_name)
        print('got model')
        self.dim = model.dim
        optimizer = torch.optim.Adam(model.parameters(),
            lr = self.model_params['lr'],
            weight_decay = self.model_params['weight_decay'],
            amsgrad = self.model_params['amsgrad'])
        print('got optimizer')
        batch_size = self.exp_params['train']['batch_size']

        train_loader = DataLoader(self.train_dataset, batch_size = batch_size, shuffle = False)
        val_loader = DataLoader(self.val_dataset, batch_size = batch_size, shuffle = False)
        tr_len = len(self.train_dataset)
        val_len = len(self.val_dataset)

        print('Training of classifier...\n')

        self.__conduct_training(model, optimizer, train_loader, val_loader, tr_len, val_len)

        torch.cuda.empty_cache()


In [None]:
composed_transforms =  transforms.Compose([
            #transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                        std=[0.229, 0.224, 0.225]),
            transforms.Resize(exp_params['transform']['resize_dim']),
            transforms.CenterCrop(exp_params['transform']['crop_dim'])
        ])
classification = Classification(sm_ftr_dataset, sm_val_dataset, sm_test_dataset)
classification.run_fgvc_pipeline()