## ResNet-50

In [None]:
import numpy as np
import pandas as pd
import os
import json

# input set up
# taxo_name selects the taxonomy family: "Baseline", "HFTT" or "FFTT".
taxo_name = "FFTT"
# taxo_idx selects one mapping inside the chosen family:
# "Baseline": [0,1], "HFTT": [0,1,2,3], "FFTT": [0,1,2,3,4,5]
taxo_idx = 3
# Names of the metadata CSV files (without extension) for each split.
dataset_train = "train"
dataset_test = "test"
dataset_extra = "scitsrcomp"
# Directories holding the images of each split.
img_path_train = "td4cltabs/train"
img_path_test = "td4cltabs/test"
img_path_extra = "td4cltabs/SciTSRComp"

Baseline_mappings = ["Baseline_I", "Baseline_II"]
HFTT_mappings = ["HFTT_Novel_I", "HFTT_Novel_II", "HFTT_Novel_III", "HFTT_Novel_IV"]
FFTT_mappings = ["FFTT_Novel_I", "FFTT_Novel_II", "FFTT_Novel_III",
                "FFTT_Novel_IV", "FFTT_Novel_V", "FFTT_Novel_VI"]

# Explicit lookup table: an unknown taxonomy name must fail loudly instead of
# silently falling through to the FFTT mappings (the rest of the notebook
# branches on `taxo_name != "FFTT"`, so a typo would otherwise mix both modes).
TAXONOMY_MAPPINGS = {
    "Baseline": Baseline_mappings,
    "HFTT": HFTT_mappings,
    "FFTT": FFTT_mappings,
}

if taxo_name not in TAXONOMY_MAPPINGS:
    raise ValueError(
        f"Unknown taxo_name {taxo_name!r}; expected one of {sorted(TAXONOMY_MAPPINGS)}"
    )
mappings = TAXONOMY_MAPPINGS[taxo_name]

# Raise instead of assert: asserts are removed when Python runs with -O.
if not 0 <= taxo_idx < len(mappings):
    raise ValueError(
        f"taxo_idx {taxo_idx} is out of range for {taxo_name!r}; "
        f"expected 0..{len(mappings) - 1}"
    )

# read related files
# Name of the selected mapping: it is the key into the labels metadata, the
# sub-directory holding the split CSVs, and the label column of each CSV.
mapping_name = mappings[taxo_idx]

with open("td4cltabs/metadata/labels_metadata.json", "r", encoding="utf-8") as input_file:
    taxo_id2names = json.load(input_file)[mapping_name]
    if taxo_name != "FFTT":
        # JSON object keys are always strings; convert them to int so they can
        # be used to rename the value_counts() index below
        # (presumably the label column holds integer class ids - confirm).
        taxo_id2names = {int(k): v for k, v in taxo_id2names.items()}

train_df = pd.read_csv(f"td4cltabs/metadata/{mapping_name}/{dataset_train}.csv",
                        index_col=[0])
test_df = pd.read_csv(f"td4cltabs/metadata/{mapping_name}/{dataset_test}.csv",
                        index_col=[0])
scitsr_df = pd.read_csv(f"td4cltabs/metadata/{mapping_name}/{dataset_extra}.csv",
                        index_col=[0])

# Print the size of every split and, for the single-label taxonomies, the
# number of instances per class. All three splits report raw counts so the
# "No. of" label is accurate (the train split used to print a fraction).
for split_title, split_df in (
    ("TD4DLTabs train", train_df),
    ("TD4DLTabs test", test_df),
    ("Scitsrcomp", scitsr_df),
):
    print("---------------------")
    print("{} No. of instances: {}".format(split_title, len(split_df[mapping_name].values)))
    if taxo_name != "FFTT":
        taxo_freqs = split_df[mapping_name].value_counts().rename(index=taxo_id2names)
        for freq_name, freq_value in taxo_freqs.items():
            print("\tNo. of {}: {}".format(freq_name, freq_value))

In [None]:
import matplotlib.pyplot as plt
import pickle
import fastai
from fastai.data.all import *
from fastai.vision.all import *
from sklearn.model_selection import StratifiedKFold
from torchvision.models import resnet50, ResNet50_Weights
from collections import Counter
from pathlib import Path

# Seed PyTorch's global random number generator.
# NOTE(review): only torch is seeded here; Python's `random` and NumPy are not,
# so runs may still differ - confirm whether full reproducibility is expected.
torch.manual_seed(42)
# Number of cross-validation folds (passed to StratifiedKFold as n_splits).
kfold = 4

In [None]:
# Metric objects handed to the learners. FFTT is evaluated with the
# multi-label metric variants (thresholded at 0.5); the other taxonomies use
# the single-label, weighted-average variants.
if taxo_name == "FFTT":
    f1_macro = F1ScoreMulti(average='macro', thresh=0.5)
    f1_weighted = F1ScoreMulti(average='weighted', thresh=0.5)
    hamming_loss = HammingLossMulti(thresh=0.5)

    # Explicit display names; the summary cell labels its columns with them.
    f1_macro.name = 'F1(macro)'
    f1_weighted.name = 'F1(weighted)'
    hamming_loss.name = 'HammingLoss'
else:
    precision_weighted = Precision(average="weighted")
    recall_weighted = Recall(average="weighted")
    f1_weighted = F1Score(average='weighted')
    f1_weighted.name = 'F1(weighted)'

# Stratified k-fold training and evaluation of a ResNet-50 classifier.
#
# For every fold: train on the fold's training part, reload the best
# checkpoint, evaluate it on the fold's validation part and on the held-out
# test set, and persist the accumulated results after each fold.

# The test set is duplicated so that EndSplitter(valid_pct=0.5,
# valid_last=False) below puts one complete copy of it into the "valid" split
# of the test DataLoaders.
# NOTE(review): re-running this cell doubles test_df again; reload the
# metadata cell first if an exact two-copy frame matters.
test_df = pd.concat([test_df]*2)

# Per-column lists of learn.validate() results, one entry per fold.
val_pct = {}
test_pct = {}

skf = StratifiedKFold(n_splits=kfold, shuffle=True, random_state=1)
col = mappings[taxo_idx]
# FFTT labels are space-delimited multi-label strings; other taxonomies are
# single-label.
is_multilabel = taxo_name == "FFTT"


def make_datablock(img_dir, splitter):
    """Build the DataBlock reading images from `img_dir` and labels from `col`.

    `splitter` decides which rows go to the train / valid split. Images are
    squished to 500x900 and normalized with the ImageNet statistics.
    """
    if is_multilabel:
        target_block = MultiCategoryBlock
        label_reader = ColReader(col, label_delim=' ')
    else:
        target_block = CategoryBlock
        label_reader = ColReader(col)

    block = DataBlock(
        blocks=(ImageBlock, target_block),
        get_x=ColReader('id', pref=img_dir + os.path.sep),
        get_y=label_reader,
        splitter=splitter,
    )
    return block.new(
        item_tfms=Resize((500, 900), method='squish', pad_mode='zeros'),
        batch_tfms=[Normalize.from_stats(*imagenet_stats)]
    )


def make_learner(loaders, path=None):
    """Create a mixed-precision ResNet-50 learner for the active taxonomy.

    `path` is the learner's base directory (checkpoints live in its model
    directory); None keeps the fastai default.
    """
    if is_multilabel:
        loss_func = BCEWithLogitsLossFlat()
        learner_metrics = [partial(accuracy_multi, thresh=0.5), f1_weighted, f1_macro, hamming_loss]
    else:
        loss_func = FocalLossFlat()
        learner_metrics = [error_rate, precision_weighted, recall_weighted, f1_weighted]
    return vision_learner(loaders, models.resnet50,
                          loss_func=loss_func,
                          metrics=learner_metrics,
                          path=path).to_fp16()


# The test loaders do not depend on the fold, so build them once.
# NOTE(review): this block derives its label vocabulary from test_df; the
# evaluation assumes it matches the training vocabulary - confirm.
test_taxos = make_datablock(img_path_test, EndSplitter(valid_pct=0.5, valid_last=False))
test_dls = test_taxos.dataloaders(test_df, bs=16)
test_dls.train = test_dls.valid

for fold, (train_index, val_index) in enumerate(skf.split(train_df.index, train_df[col])):
    print("Fold: ", fold)

    train_taxos = make_datablock(img_path_train, IndexSplitter(val_index))
    dls = train_taxos.dataloaders(train_df, bs=16)

    # Create a learner
    learn = make_learner(dls)

    learn.fine_tune(epochs=15, freeze_epochs=5, cbs=[SaveModelCallback(with_opt=True, fname=f"{col}_{fold}_bestmodel"),
                    EarlyStoppingCallback(monitor='valid_loss', patience=5)])

    # Open the figure BEFORE plotting so savefig() writes the loss curves;
    # creating it after plot_loss() saved a new, empty figure instead.
    plt.figure()
    learn.recorder.plot_loss(skip_start=0, with_valid=True)
    plt.savefig(f'loss_plot({col}_{fold}).png')

    # Fresh learner for evaluation. It uses the training learner's path so the
    # checkpoint is looked up where SaveModelCallback wrote it, in both modes.
    learn2 = make_learner(dls, path=learn.path)
    learn2 = learn2.load(f"{col}_{fold}_bestmodel")

    # Metrics on this fold's validation split.
    val = learn2.validate()

    # Swap in the test loader so validate() reports metrics on the test set.
    learn2.dls.valid = test_dls.valid

    test = learn2.validate()

    val_pct.setdefault(col, []).append(val)
    test_pct.setdefault(col, []).append(test)

    # Persist after every fold so partial results survive an interrupted run.
    with open(f"{col}_val_pct.pkl", "wb") as file:
        pickle.dump(val_pct, file)
    with open(f"{col}_test_pct.pkl", "wb") as file:
        pickle.dump(test_pct, file)

In [None]:
# Show the validation and test results currently bound to `val` and `test`,
# one per line.
print(val, test, sep="\n")

In [None]:
# Report mean and standard deviation of every metric across the folds, for
# both the validation and the test results. The metric names are listed in the
# same order as the values inside each per-fold result.
metrics = (
    ['valid_loss', 'accuracy_multi', 'F1(weighted)', 'F1(macro)', 'HammingLoss']
    if taxo_name == "FFTT"
    else ['valid_loss', 'error_rate', 'precision_score', 'recall_score', 'F1(weighted)']
)

print(f"-------{col}---------")
val = val_pct[col]
test = test_pct[col]

for idx, metric in enumerate(metrics):
    # Gather the idx-th value of every fold's validation result.
    val_scores = [fold_result[idx] for fold_result in val]
    print(f'{metric} Validation: \tmean: {np.mean(val_scores)} \tstd: {np.std(val_scores)}')

    # Same for the test results.
    test_scores = [fold_result[idx] for fold_result in test]
    print(f'{metric} Test:       \tmean: {np.mean(test_scores)} \tstd: {np.std(test_scores)}\n')