# Introduction 

**This is a basic CNN Model training notebook**

It is based on: 
- Thumbnail images
- Basic data transformation (using Albumentation):
    - resizing images to 512x512
    - normalizing pixel values
- CNN Architecture


**Todos:**

- Learn about Dataset & DataLoader
- add augmentations (albumentation)
- gem pooling

In [1]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        if "best_model" in filename:
            print(dirname+"/"+filename)
        continue

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

/kaggle/input/efficientnetb0-training-crop-images/best_model_checkpoint2023-11-06_20-05-37.pth
/kaggle/input/efficientnetb0-training-crop-images/best_model_checkpoint2023-11-06_20-43-31.pth
/kaggle/input/efficientnetb0-training-crop-images/best_model_checkpoint2023-11-06_21-17-56.pth
/kaggle/input/efficientnetb0-training-crop-images/best_model_checkpoint2023-11-06_19-27-54.pth


In [2]:


import os
import gc
import cv2
import math
import copy
import time
import random
import glob
from matplotlib import pyplot as plt

# For data manipulation
import numpy as np
import pandas as pd

# Pytorch Imports
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.optim import lr_scheduler
from torch.utils.data import Dataset, DataLoader
from torch.cuda import amp
import torchvision

# Utils
import joblib
from tqdm import tqdm
from collections import defaultdict

# Sklearn Imports
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import balanced_accuracy_score, confusion_matrix, f1_score

import albumentations as A
from albumentations.pytorch import ToTensorV2

# For Image Models
import timm

# Albumentations for augmentations
# import albumentations as A
# from albumentations.pytorch import ToTensorV2

# For colored terminal text
from colorama import Fore, Back, Style
b_ = Fore.BLUE
sr_ = Style.RESET_ALL

import warnings
# warnings.filterwarnings("ignore")

# For descriptive error messages
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"

In [3]:
CONFIG = {
    "is_submission": False,
    "n_fold": 5,
    'fold': 1,
    "seed": 42,
    "img_size": 512,
    "crop_vertical":True,
    "model_name": "tf_efficientnet_b0_ns",
    "num_classes": 5,
    "valid_batch_size": 16,
    "device": torch.device("cuda:0" if torch.cuda.is_available() else "cpu"),
    # "model_path": '/kaggle/input/efficientnetb0-training-crop-images/best_model_checkpoint2023-10-26_09-10-29.pth',
    "encoder_path": "/kaggle/input/efficientnetb0-training-crop-images/label_encoder_2023-11-06_19-27-47.pkl"
}

## 1. Data Preparation

In [4]:
ROOT_DIR = '/kaggle/input/UBC-OCEAN'
TRAIN_DIR = '/kaggle/input/UBC-OCEAN/train_thumbnails'
TEST_DIR = '/kaggle/input/UBC-OCEAN/test_thumbnails'
ALT_TEST_DIR = '/kaggle/input/UBC-OCEAN/test_images'
TMA_TRAIN_DIR = '/kaggle/input/UBC-OCEAN/train_images'


def get_train_file_path(df_train_row):
    if df_train_row.is_tma == False:
        return f"{TRAIN_DIR}/{df_train_row.image_id}_thumbnail.png"
    else:
        return f"{TMA_TRAIN_DIR}/{df_train_row.image_id}.png"



def get_test_file_path(image_id):
    if os.path.exists(f"{TEST_DIR}/{image_id}_thumbnail.png"):
        return f"{TEST_DIR}/{image_id}_thumbnail.png"
    else:
        return f"{ALT_TEST_DIR}/{image_id}.png"



In [5]:
encoder = joblib.load(CONFIG["encoder_path"])

In [6]:
df_test = pd.read_csv("/kaggle/input/UBC-OCEAN/test.csv")
df_test['file_path'] = df_test['image_id'].apply(get_test_file_path)
df_test["target_label"] = 0 
df_test

Unnamed: 0,image_id,image_width,image_height,file_path,target_label
0,41,28469,16987,/kaggle/input/UBC-OCEAN/test_thumbnails/41_thu...,0


In [7]:
class UBCDataset(Dataset):
    def __init__(self, df, transforms=None):
        self.df = df
        self.filenames = df.file_path.values
        self.labels =  df.target_label.values
        self.transforms = transforms

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        img_path = self.filenames[idx]
        img = cv2.imread(img_path)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        
        if CONFIG["crop_vertical"]:
            img = crop_vertical(img)
        
        # img = custom_center_crop_or_resize(img, (1024, 1024))
        
        if self.transforms:
            img = self.transforms(image=img)["image"]
            
        return {
            "image": img,
            "label": torch.tensor(self.labels[idx], dtype=torch.long)
               }

def crop_vertical(image):
    """
    Function crops images if multiple slices contained and separated by black vertical background.
    """
    vertical_sum = np.sum(image, axis=(0, 2))

    # Identify the positions where the sum is zero
    zero_positions = np.where(vertical_sum == 0)[0]

    if len(zero_positions)==0:
        cropped_images = [image]
    else:
        # If the image does not start with a black area, add index 0
        if zero_positions[0] != 0:
            zero_positions = np.insert(zero_positions, 0, 0)

        # If the image does not end with a black area, add the image width
        if zero_positions[-1] != image.shape[1] - 1:
            zero_positions = np.append(zero_positions, image.shape[1] - 1)

        start_idx = zero_positions[0]
        cropped_images = []

        for idx in range(1, len(zero_positions)):
            end_idx = zero_positions[idx]
            if end_idx - start_idx > 1:  # If the width of the cropped section is greater than 1
                cropped = image[:, start_idx:end_idx]
                # only include samples which are of min size
                if cropped.shape[1]>200:  
                    cropped_images.append(cropped)
                    # cv2.imwrite(f"{save_prefix}_{idx}.jpg", cropped)
            start_idx = end_idx

    final_crops = []
    # remove black bars above/below the crops 
    for cropped in cropped_images:
        horizontal_sum = np.sum(cropped, axis=(1, 2))
        zero_positions = np.where(horizontal_sum == 0)[0]
        img_ = np.delete(cropped, zero_positions, axis=0)
        final_crops.append(img_)
    if len(final_crops)==0:
        return image
    return final_crops[0]


def custom_center_crop_or_resize(image, crop_size):
    # If both dimensions of the image are greater than or equal to the desired size, apply CenterCrop
    if image.shape[0] >= crop_size[0] and image.shape[1] >= crop_size[1]:
        return A.CenterCrop(crop_size[0], crop_size[1])(image=image)["image"]
    # Else, just resize the image to the desired size
    else:
        return A.Resize(crop_size[0], crop_size[1])(image=image)["image"]



In [8]:
data_transforms = {
    "train": A.Compose([
        A.RandomResizedCrop(CONFIG['img_size'], CONFIG['img_size'], scale=(0.8, 1.0)),
        A.HorizontalFlip(p=0.5),
        A.RandomBrightnessContrast(p=0.2),
        A.ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.2, rotate_limit=15, p=0.2),
        A.HueSaturationValue(hue_shift_limit=20, sat_shift_limit=30, val_shift_limit=20, p=0.5),
        A.CoarseDropout(p=0.2),
        A.Cutout(p=0.2),
        A.Normalize(
            mean=[0.485, 0.456, 0.406], 
            std=[0.229, 0.224, 0.225], 
            max_pixel_value=255.0, 
            p=1.0
        ),
        ToTensorV2()], p=1.),
    
    "valid": A.Compose([
        A.Resize(CONFIG['img_size'], CONFIG['img_size']),
        A.Normalize(
            mean=[0.485, 0.456, 0.406], 
            std=[0.229, 0.224, 0.225], 
            max_pixel_value=255.0, 
            p=1.0
        ),
        ToTensorV2()], p=1.)
}



## 2. Model Creation

In [9]:
class GeM(nn.Module):
    def __init__(self, p=3, eps=1e-6):
        super(GeM, self).__init__()
        self.p = nn.Parameter(torch.ones(1)*p)
        self.eps = eps

    def forward(self, x):
        return self.gem(x, p=self.p, eps=self.eps)
        
    def gem(self, x, p=3, eps=1e-6):
        return F.avg_pool2d(x.clamp(min=eps).pow(p), (x.size(-2), x.size(-1))).pow(1./p)
        
    def __repr__(self):
        return self.__class__.__name__ + \
                '(' + 'p=' + '{:.4f}'.format(self.p.data.tolist()[0]) + \
                ', ' + 'eps=' + str(self.eps) + ')'


class EfficientNetB0(nn.Module):
    '''
    EfficientNet B0 fine-tune.
    '''
    def __init__(self, model_name, num_classes, pretrained=False, checkpoint_path=None):
        '''
        Fine tune for EfficientNetB0
        Args
            n_classes : int - Number of classification categories.
            learnable_modules : tuple - Names of the modules to fine-tune.
        Return
            
        '''
        super(EfficientNetB0, self).__init__()
        self.model = timm.create_model(model_name, pretrained=pretrained, checkpoint_path=checkpoint_path)

        in_features = self.model.classifier.in_features
        self.model.classifier = nn.Identity()
        self.model.global_pool = nn.Identity()
        self.pooling = GeM()
        self.linear = nn.Linear(in_features, num_classes)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, images):
        """
        Forward function for the fine-tuned model
        Args
            x: 
        Return
            result
        """
        features = self.model(images)
        pooled_features = self.pooling(features).flatten(1)
        output = self.linear(pooled_features)
        return output



In [10]:
model_filepaths = []
models = []

for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        if "best_model" in filename:
            curr_model_path = dirname+"/"+filename
            print(curr_model_path)
            model_filepaths.append(curr_model_path)
        continue

for curr_model_path in sorted(model_filepaths):
    print(curr_model_path)
    model = EfficientNetB0(CONFIG['model_name'], CONFIG['num_classes'], pretrained=False)
    model.to(CONFIG['device']);
    model.load_state_dict(torch.load(curr_model_path, map_location=torch.device('cpu')))
    models.append(model)
print(f"Numer of models loaded: {len(models)}")

  model = create_fn(


/kaggle/input/efficientnetb0-training-crop-images/best_model_checkpoint2023-11-06_20-05-37.pth
/kaggle/input/efficientnetb0-training-crop-images/best_model_checkpoint2023-11-06_20-43-31.pth
/kaggle/input/efficientnetb0-training-crop-images/best_model_checkpoint2023-11-06_21-17-56.pth
/kaggle/input/efficientnetb0-training-crop-images/best_model_checkpoint2023-11-06_19-27-54.pth
/kaggle/input/efficientnetb0-training-crop-images/best_model_checkpoint2023-11-06_19-27-54.pth
/kaggle/input/efficientnetb0-training-crop-images/best_model_checkpoint2023-11-06_20-05-37.pth
/kaggle/input/efficientnetb0-training-crop-images/best_model_checkpoint2023-11-06_20-43-31.pth
/kaggle/input/efficientnetb0-training-crop-images/best_model_checkpoint2023-11-06_21-17-56.pth
Numer of models loaded: 4


In [11]:
"""model_new = EfficientNetB0(CONFIG['model_name'], CONFIG['num_classes'], pretrained=False)
# weights = torch.load( BEST_WEIGHT )
# model.load_state_dict(weights)
model_new.to(CONFIG['device']);
model_new.load_state_dict(torch.load(CONFIG["model_path"], map_location=torch.device('cpu')))"""

'model_new = EfficientNetB0(CONFIG[\'model_name\'], CONFIG[\'num_classes\'], pretrained=False)\n# weights = torch.load( BEST_WEIGHT )\n# model.load_state_dict(weights)\nmodel_new.to(CONFIG[\'device\']);\nmodel_new.load_state_dict(torch.load(CONFIG["model_path"], map_location=torch.device(\'cpu\')))'

In [12]:
# Inspect the weights of the model
# for name, parameter in model.named_parameters():
#     print(f'{name}: {parameter.shape}')
# model.state_dict()


# Validation 

In [13]:
def predict_val_dataset(model, CONFIG, df_validate, TRAIN_DIR=None, val_size=1.0):
    if not CONFIG["is_submission"]:
        valid_dataset = UBCDataset(df_validate, transforms=data_transforms["valid"])
        valid_loader = DataLoader(valid_dataset, batch_size=CONFIG['valid_batch_size'], 
                              num_workers=2, shuffle=False, pin_memory=True)

        preds = []
        labels_list = []
        valid_acc = 0.0

        with torch.no_grad():
            bar = tqdm(enumerate(valid_loader), total=len(valid_loader))
            for step, data in bar: 
                # print(step)
                images = data['image'].to(CONFIG["device"], dtype=torch.float)
                labels = data['label'].to(CONFIG["device"], dtype=torch.long)

                batch_size = images.size(0)
                outputs = model(images)
                _, predicted = torch.max(model.softmax(outputs), 1)
                preds.append( predicted.detach().cpu().numpy() )
                labels_list.append( labels.detach().cpu().numpy() )
                acc = torch.sum( predicted == labels )
                valid_acc  += acc.item()
        valid_acc /= len(valid_loader.dataset)
        preds = np.concatenate(preds).flatten()
        labels_list = np.concatenate(labels_list).flatten()
        pred_labels = encoder.inverse_transform( preds )
        
        # Calculate Balanced Accuracy
        bal_acc = balanced_accuracy_score(labels_list, preds)
        # Calculate Confusion Matrix
        conf_matrix = confusion_matrix(labels_list, preds)
        macro_f1 = f1_score(labels_list, preds, average='macro')
        micro_f1 = f1_score(labels_list, preds, average='micro')
        weighted_f1 = f1_score(labels_list, preds, average='weighted')
    
        print(f"Validation Accuracy: {valid_acc}")
        print(f"Balanced Accuracy: {bal_acc}")
        print(f"Macro F1-Score: {macro_f1}")
        print(f"Micro F1-Score: {micro_f1}")
        print(f"Weighted F1-Score: {weighted_f1}")
        print(f"Confusion Matrix: {conf_matrix}")
        
        # add to validation dataframe
        df_validate["pred"] = preds
        df_validate["pred_labels"] = pred_labels
        return df_validate, preds, labels_list
    else:
        print("Skip validation on training set due to submission!")
        return None

# Model Validation

In [14]:
if not CONFIG["is_submission"]: 
    df_train = pd.read_csv("/kaggle/input/UBC-OCEAN/train.csv")
    print(df_train.shape)
    df_train['file_path'] = df_train.apply(lambda row: get_train_file_path(row), axis=1)
    df_train['target_label'] = encoder.transform(df_train['label'])
    # use stratified K Fold for crossvalidation 
    skf = StratifiedKFold(n_splits=CONFIG['n_fold'], shuffle=True, random_state=CONFIG["seed"])
    for fold, ( _, val_) in enumerate(skf.split(X=df_train, y=df_train.target_label)):
        df_train.loc[val_ , "kfold"] = int(fold)
        
    all_labels = []
    all_predictions = []
    for fold, model in enumerate(models):
        fold += 1
        print("Evaluate Fold: ", fold)
        model.eval()
        model.to(CONFIG["device"])
        df_train_fold = df_train[df_train["kfold"]!=fold].reset_index(drop=True)
        df_valid_fold = df_train[df_train["kfold"]==fold].reset_index(drop=True)

        #train_dataset = UBCDataset(df_train_fold, transforms=data_transforms["train"])
        #train_loader = DataLoader(train_dataset, batch_size=CONFIG['train_batch_size'], 
        #                          num_workers=2, shuffle=False, pin_memory=True)
        valid_dataset = UBCDataset(df_valid_fold, transforms=data_transforms["valid"])
        valid_loader = DataLoader(valid_dataset, batch_size=CONFIG['valid_batch_size'], 
                                  num_workers=2, shuffle=False, pin_memory=True)

        df_validate, predictions, labels = predict_val_dataset(model, CONFIG, df_valid_fold, TRAIN_DIR, val_size=1)
        all_labels.extend(labels)
        all_predictions.extend(predictions)
        display(df_validate)



(538, 5)
Evaluate Fold:  1


100%|██████████| 7/7 [00:47<00:00,  6.77s/it]

Validation Accuracy: 0.5740740740740741
Balanced Accuracy: 0.6137777777777778
Macro F1-Score: 0.5809937513466925
Micro F1-Score: 0.5740740740740741
Weighted F1-Score: 0.5689769925064043
Confusion Matrix: [[18  0  1  1  0]
 [ 5 12  5  2  1]
 [ 8 11 21  4  1]
 [ 0  1  3  5  0]
 [ 1  2  0  0  6]]





Unnamed: 0,image_id,label,image_width,image_height,is_tma,file_path,target_label,kfold,pred,pred_labels
0,1252,HGSC,60420,27480,False,/kaggle/input/UBC-OCEAN/train_thumbnails/1252_...,2,1.0,0,CC
1,1289,HGSC,43940,26785,False,/kaggle/input/UBC-OCEAN/train_thumbnails/1289_...,2,1.0,1,EC
2,1660,CC,83340,20447,False,/kaggle/input/UBC-OCEAN/train_thumbnails/1660_...,0,1.0,0,CC
3,2097,HGSC,31696,21984,False,/kaggle/input/UBC-OCEAN/train_thumbnails/2097_...,2,1.0,1,EC
4,2391,LGSC,58075,26192,False,/kaggle/input/UBC-OCEAN/train_thumbnails/2391_...,3,1.0,3,LGSC
...,...,...,...,...,...,...,...,...,...,...
103,61823,EC,32807,19314,False,/kaggle/input/UBC-OCEAN/train_thumbnails/61823...,1,1.0,2,HGSC
104,61852,HGSC,33401,36180,False,/kaggle/input/UBC-OCEAN/train_thumbnails/61852...,2,1.0,2,HGSC
105,63015,CC,52483,35320,False,/kaggle/input/UBC-OCEAN/train_thumbnails/63015...,0,1.0,0,CC
106,64188,HGSC,77833,30683,False,/kaggle/input/UBC-OCEAN/train_thumbnails/64188...,2,1.0,2,HGSC


Evaluate Fold:  2


100%|██████████| 7/7 [00:47<00:00,  6.81s/it]

Validation Accuracy: 0.6296296296296297
Balanced Accuracy: 0.6351866028708134
Macro F1-Score: 0.6207759080096034
Micro F1-Score: 0.6296296296296297
Weighted F1-Score: 0.6270103721365677
Confusion Matrix: [[14  2  2  0  1]
 [ 5 12  5  0  3]
 [ 3 10 29  1  1]
 [ 2  1  3  4  0]
 [ 0  1  0  0  9]]





Unnamed: 0,image_id,label,image_width,image_height,is_tma,file_path,target_label,kfold,pred,pred_labels
0,66,LGSC,48871,48195,False,/kaggle/input/UBC-OCEAN/train_thumbnails/66_th...,3,2.0,2,HGSC
1,281,LGSC,42309,15545,False,/kaggle/input/UBC-OCEAN/train_thumbnails/281_t...,3,2.0,2,HGSC
2,286,EC,37204,30020,False,/kaggle/input/UBC-OCEAN/train_thumbnails/286_t...,1,2.0,1,EC
3,1020,HGSC,36585,33751,False,/kaggle/input/UBC-OCEAN/train_thumbnails/1020_...,2,2.0,2,HGSC
4,1080,HGSC,31336,23200,False,/kaggle/input/UBC-OCEAN/train_thumbnails/1080_...,2,2.0,2,HGSC
...,...,...,...,...,...,...,...,...,...,...
103,63121,EC,53469,44300,False,/kaggle/input/UBC-OCEAN/train_thumbnails/63121...,1,2.0,1,EC
104,63298,HGSC,26067,20341,False,/kaggle/input/UBC-OCEAN/train_thumbnails/63298...,2,2.0,2,HGSC
105,64629,HGSC,25480,14920,False,/kaggle/input/UBC-OCEAN/train_thumbnails/64629...,2,2.0,2,HGSC
106,65300,HGSC,75860,27503,False,/kaggle/input/UBC-OCEAN/train_thumbnails/65300...,2,2.0,2,HGSC


Evaluate Fold:  3


100%|██████████| 7/7 [00:49<00:00,  7.01s/it]

Validation Accuracy: 0.6448598130841121
Balanced Accuracy: 0.5845454545454545
Macro F1-Score: 0.5776386386873751
Micro F1-Score: 0.6448598130841121
Weighted F1-Score: 0.6522502665319181
Confusion Matrix: [[12  4  0  4  0]
 [ 2 14  4  2  2]
 [ 4  2 34  4  0]
 [ 1  1  3  3  2]
 [ 1  2  0  0  6]]





Unnamed: 0,image_id,label,image_width,image_height,is_tma,file_path,target_label,kfold,pred,pred_labels
0,4,HGSC,23785,20008,False,/kaggle/input/UBC-OCEAN/train_thumbnails/4_thu...,2,3.0,2,HGSC
1,706,HGSC,75606,25965,False,/kaggle/input/UBC-OCEAN/train_thumbnails/706_t...,2,3.0,2,HGSC
2,970,HGSC,32131,18935,False,/kaggle/input/UBC-OCEAN/train_thumbnails/970_t...,2,3.0,2,HGSC
3,1666,HGSC,69900,16083,False,/kaggle/input/UBC-OCEAN/train_thumbnails/1666_...,2,3.0,2,HGSC
4,1774,HGSC,44231,37571,False,/kaggle/input/UBC-OCEAN/train_thumbnails/1774_...,2,3.0,1,EC
...,...,...,...,...,...,...,...,...,...,...
102,61089,HGSC,32369,19660,False,/kaggle/input/UBC-OCEAN/train_thumbnails/61089...,2,3.0,2,HGSC
103,61961,CC,47531,45240,False,/kaggle/input/UBC-OCEAN/train_thumbnails/61961...,0,3.0,0,CC
104,63165,CC,30342,12783,False,/kaggle/input/UBC-OCEAN/train_thumbnails/63165...,0,3.0,1,EC
105,64771,EC,29163,27940,False,/kaggle/input/UBC-OCEAN/train_thumbnails/64771...,1,3.0,1,EC


Evaluate Fold:  4


100%|██████████| 7/7 [01:06<00:00,  9.54s/it]

Validation Accuracy: 0.6074766355140186
Balanced Accuracy: 0.6180808080808081
Macro F1-Score: 0.6055602744802853
Micro F1-Score: 0.6074766355140186
Weighted F1-Score: 0.6092156243544018
Confusion Matrix: [[14  4  2  0  0]
 [ 2 15  6  2  0]
 [ 8  6 25  4  1]
 [ 3  1  2  3  0]
 [ 0  1  0  0  8]]





Unnamed: 0,image_id,label,image_width,image_height,is_tma,file_path,target_label,kfold,pred,pred_labels
0,91,HGSC,3388,3388,True,/kaggle/input/UBC-OCEAN/train_images/91.png,2,4.0,1,EC
1,1952,CC,33685,38053,False,/kaggle/input/UBC-OCEAN/train_thumbnails/1952_...,0,4.0,0,CC
2,3672,LGSC,62463,21527,False,/kaggle/input/UBC-OCEAN/train_thumbnails/3672_...,3,4.0,3,LGSC
3,3997,MC,49467,29610,False,/kaggle/input/UBC-OCEAN/train_thumbnails/3997_...,4,4.0,4,MC
4,4608,EC,33155,39867,False,/kaggle/input/UBC-OCEAN/train_thumbnails/4608_...,1,4.0,1,EC
...,...,...,...,...,...,...,...,...,...,...
102,63289,CC,32380,35307,False,/kaggle/input/UBC-OCEAN/train_thumbnails/63289...,0,4.0,0,CC
103,63897,HGSC,9022,23582,False,/kaggle/input/UBC-OCEAN/train_thumbnails/63897...,2,4.0,4,MC
104,64111,HGSC,15549,8129,False,/kaggle/input/UBC-OCEAN/train_thumbnails/64111...,2,4.0,1,EC
105,64824,CC,46589,19365,False,/kaggle/input/UBC-OCEAN/train_thumbnails/64824...,0,4.0,1,EC


# Prediction on Test Data

In [15]:
# Predict on Test Dataset
test_dataset = UBCDataset(df_test, transforms=data_transforms["valid"])
test_loader = DataLoader(test_dataset, batch_size=CONFIG['valid_batch_size'], 
                          num_workers=2, shuffle=False, pin_memory=True)

preds = []
with torch.no_grad():
    bar = tqdm(enumerate(test_loader), total=len(test_loader))
    for step, data in bar: 
        images = data['image'].to(CONFIG["device"], dtype=torch.float)
        ensemble_output = 0

        for model in models:
            model.eval()
            model.to(CONFIG["device"])
            outputs = model(images)
            probabilities = torch.softmax(outputs, dim=1)
            ensemble_output += probabilities

        _, predicted = torch.max(ensemble_output, 1)
        preds.append(predicted.detach().cpu().numpy())

preds = np.concatenate(preds).flatten()
pred_labels = encoder.inverse_transform(preds)

100%|██████████| 1/1 [00:03<00:00,  3.62s/it]


In [16]:
df_sub = pd.read_csv(f"{ROOT_DIR}/sample_submission.csv")
df_sub["label"] = pred_labels
df_sub.to_csv("submission.csv", index=False)

In [17]:
df_sub

Unnamed: 0,image_id,label
0,41,HGSC
