In [None]:
import pandas as pd

import random
import numpy as np

from transformers import AutoConfig, AutoModel, AutoTokenizer, RobertaTokenizer
#from iterstrat.ml_stratifiers import MultilabelStratifiedShuffleSplit

import torch
import torch.nn as nn
import torch.nn.functional as F

from sklearn.metrics import f1_score, roc_auc_score, accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MultiLabelBinarizer, LabelEncoder
from torch.utils.data import TensorDataset, DataLoader

from tqdm import tqdm
from datetime import datetime
from collections import Counter, defaultdict
import os
import shutil
from itertools import chain

from sklearn.metrics import roc_curve, auc
import matplotlib.pyplot as plt

import ast

In [None]:
# Load the fine-tuning dataset and hold out 30% of it for validation.
# random_state is fixed so the split is the same on every run.
df = pd.read_csv("/content/finetuning_data.csv")
finetune_train, finetune_valid = train_test_split(df, test_size=0.3, random_state=42)

In [None]:
# Load the AMT10 train/validation splits; the first CSV column becomes the index.
# In this notebook train_df is only used to fit the label encoders.
train_df = pd.read_csv('/content/AMT10_train.csv', index_col=0, encoding='utf8')
valid_df = pd.read_csv('/content/AMT10_validation.csv', index_col=0, encoding='utf8')

In [None]:
# Tags the model is trained to predict. The filtering cells below drop every
# other tag and every row that carries none of these.
AMT10 = [
    'implementation',
    'dp',
    'math',
    'greedy',
    'data structures',
    'brute force',
    'geometry',
    'constructive algorithms',
    'dfs and similar',
    'strings'
]

In [None]:
# BigBird-RoBERTa base configuration with max_position_embeddings overridden to 1024.
# Only the configuration is fetched here; the weights come from a checkpoint loaded later.
model_config = AutoConfig.from_pretrained("google/bigbird-roberta-base", max_position_embeddings=1024)
model_config

In [None]:
# Global settings read by set_seed, the tag-filtering cells, the classifier and the Trainer.
config = {
    'seed' : 42,  # passed to set_seed
    'tags' : AMT10,  # tags kept by the filtering cells
    'batchSize' : 4,  # DataLoader batch size for both training and validation
    'lr' : 5e-6,  # Adam learning rate
    'trainMaxLength' : 1024,  # tokenizer max_length for the training texts
    'testMaxLength' : 1024,  # tokenizer max_length for the validation texts
    'numEpochs' : 200,
    'model' : AutoModel.from_config(model_config),  # encoder built from the config above
    'tokenizer' : RobertaTokenizer.from_pretrained('roberta-base'),
    'gradient_accumulation_steps' : 4,
    'max_grad_norm' : 1.0,  # gradient-clipping norm; clipping is skipped when <= 0
    'lambda' : 10,  # multiplier applied to the tag loss before adding the rating loss
    'save' : True,  # whether the Trainer writes checkpoints
}

In [None]:
def set_seed(seed):
    """Seed Python's `random`, NumPy and PyTorch (CPU, plus every CUDA device if present).

    Args:
        seed: Integer seed applied to all generators.
    """
    # Same order as before: stdlib, NumPy, then torch.
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
# Seed every random number generator with the configured seed.
set_seed(config['seed'])

In [None]:
# Keep only the rows of train_df that carry at least one configured tag, and
# reduce each kept row's tag list to the configured tags.
new_train_idx = []  # Index labels of the rows to keep
selected_train_tags = []  # Filtered tag list for each kept row (same order)

# Iterate the tags column directly instead of materialising a full row per index.
for index, raw_tags in train_df['tags'].items():
    # Tags come from the CSV as the string repr of a list. Once the filtered lists
    # have been written back into the column they are already lists, so only
    # strings are parsed; this keeps the cell re-runnable.
    tags = ast.literal_eval(raw_tags) if isinstance(raw_tags, str) else raw_tags
    kept = [tag for tag in tags if tag in config['tags']]

    # Rows without any configured tag are dropped.
    if kept:
        selected_train_tags.append(kept)
        new_train_idx.append(index)

print(len(new_train_idx))  # Number of rows kept

In [None]:
# Keep only the rows of finetune_train that carry at least one configured tag,
# and reduce each kept row's tag list to the configured tags.
new_finetune_train_idx = []  # Index labels of the rows to keep
selected_finetune_train_tags = []  # Filtered tag list for each kept row (same order)

# Iterate the tags column directly instead of materialising a full row per index.
for index, raw_tags in finetune_train['tags'].items():
    # Tags come from the CSV as the string repr of a list. Once the filtered lists
    # have been written back into the column they are already lists, so only
    # strings are parsed; this keeps the cell re-runnable.
    tags = ast.literal_eval(raw_tags) if isinstance(raw_tags, str) else raw_tags
    kept = [tag for tag in tags if tag in config['tags']]

    # Rows without any configured tag are dropped.
    if kept:
        selected_finetune_train_tags.append(kept)
        new_finetune_train_idx.append(index)

print(len(new_finetune_train_idx))  # Number of rows kept

In [None]:
# Keep only the rows of valid_df that carry at least one configured tag, and
# reduce each kept row's tag list to the configured tags.
new_valid_idx = []  # Index labels of the rows to keep
selected_valid_tags = []  # Filtered tag list for each kept row (same order)

# Iterate the tags column directly instead of materialising a full row per index.
for index, raw_tags in valid_df['tags'].items():
    # Tags come from the CSV as the string repr of a list. Once the filtered lists
    # have been written back into the column they are already lists, so only
    # strings are parsed; this keeps the cell re-runnable.
    tags = ast.literal_eval(raw_tags) if isinstance(raw_tags, str) else raw_tags
    kept = [tag for tag in tags if tag in config['tags']]

    # Rows without any configured tag are dropped.
    if kept:
        selected_valid_tags.append(kept)
        new_valid_idx.append(index)

print(len(new_valid_idx))  # Number of rows kept

In [None]:
# Keep only the rows of finetune_valid that carry at least one configured tag,
# and reduce each kept row's tag list to the configured tags.
new_finetune_valid_idx = []  # Index labels of the rows to keep
selected_finetune_valid_tags = []  # Filtered tag list for each kept row (same order)

# Iterate the tags column directly instead of materialising a full row per index.
for index, raw_tags in finetune_valid['tags'].items():
    # Tags come from the CSV as the string repr of a list. Once the filtered lists
    # have been written back into the column they are already lists, so only
    # strings are parsed; this keeps the cell re-runnable.
    tags = ast.literal_eval(raw_tags) if isinstance(raw_tags, str) else raw_tags
    kept = [tag for tag in tags if tag in config['tags']]

    # Rows without any configured tag are dropped.
    if kept:
        selected_finetune_valid_tags.append(kept)
        new_finetune_valid_idx.append(index)

print(len(new_finetune_valid_idx))  # Number of rows kept

In [None]:
# Restrict both AMT10 frames to the kept rows and replace their raw tag strings
# with the filtered tag lists (lists are aligned with the kept index order).
train_df = train_df.loc[new_train_idx].assign(tags=selected_train_tags)
valid_df = valid_df.loc[new_valid_idx].assign(tags=selected_valid_tags)

In [None]:
# Restrict both fine-tuning frames to the kept rows and replace their raw tag
# strings with the filtered tag lists (lists are aligned with the kept index order).
finetune_train = finetune_train.loc[new_finetune_train_idx].assign(tags=selected_finetune_train_tags)
finetune_valid = finetune_valid.loc[new_finetune_valid_idx].assign(tags=selected_finetune_valid_tags)

In [None]:
# Model inputs: the problem descriptions of the fine-tuning splits.
X_train = finetune_train['description']
X_test = finetune_valid['description']

# Targets for the training split: filtered tag lists and integer ratings.
y_tags_train = finetune_train['tags']
y_ratings_train = finetune_train['rating'].astype(int)

# Targets for the held-out split.
y_tags_test = finetune_valid['tags']
y_ratings_test = finetune_valid['rating'].astype(int)

In [None]:
# Encoders: multi-hot vectors for tag lists, integer class ids for ratings.
tag_label_encoder = MultiLabelBinarizer()
rating_label_encoder = LabelEncoder()

# Both encoders learn their classes from the AMT10 training frame, not from the
# fine-tuning data; only the fitted state is needed, so the transformed output
# is not computed here.
tag_label_encoder.fit(train_df['tags'])
rating_label_encoder.fit(train_df['rating'].astype(int))

# Encode the fine-tuning targets with the fitted encoders.
# NOTE(review): a rating value absent from train_df would make
# LabelEncoder.transform raise - confirm the fine-tuning ratings are covered.
y_tags_train, y_tags_test = (
    tag_label_encoder.transform(tag_lists) for tag_lists in (y_tags_train, y_tags_test)
)
y_ratings_train, y_ratings_test = (
    rating_label_encoder.transform(ratings) for ratings in (y_ratings_train, y_ratings_test)
)

In [None]:
# Multi-label head: one independent probability per label
class MultiLabelClassificationHead(nn.Module):
    """Linear layer followed by a sigmoid, producing one value in (0, 1) per label.

    Args:
        num_labels: Number of output labels.
        hidden_size: Size of the incoming feature vector (default 768).
    """

    def __init__(self, num_labels, hidden_size=768):
        super().__init__()
        # Attribute names are part of the state_dict keys ('fc.weight', 'fc.bias').
        self.fc = nn.Linear(hidden_size, num_labels)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return the per-label sigmoid activations for the features `x`."""
        return self.sigmoid(self.fc(x))

# Multi-class head: raw logits, one per class
class MultiClassClassificationHead(nn.Module):
    """Single linear layer that returns unnormalised class scores (logits).

    Args:
        num_labels: Number of classes.
        hidden_size: Size of the incoming feature vector (default 768).
    """

    def __init__(self, num_labels, hidden_size=768):
        super().__init__()
        # Attribute name is part of the state_dict keys ('fc.weight', 'fc.bias').
        self.fc = nn.Linear(hidden_size, num_labels)

    def forward(self, x):
        """Return the class logits for the features `x` (no activation applied)."""
        logits = self.fc(x)
        return logits

# Joint classifier: shared encoder + tag head + rating head
class classifier(nn.Module):
    """Encoder with a multi-label tag head and a multi-class rating head.

    Both heads are initialised from previously saved state dicts, and a single
    Adam optimizer is created over the encoder and both heads.

    Args:
        model: Encoder whose output exposes `pooler_output`.
        device: Device the heads and loss modules are moved to.
        tags_num_classes: Output size of the tag head.
        ratings_num_classes: Output size of the rating head.
        tag_state_dict: State dict loaded into the tag head.
        rating_state_dict: State dict loaded into the rating head.
    """

    def __init__(self, model, device, tags_num_classes, ratings_num_classes, tag_state_dict, rating_state_dict):
        super().__init__()
        self.tags_num_classes = tags_num_classes  # Number of classes for tags
        self.ratings_num_classes = ratings_num_classes  # Number of classes for ratings

        # Set the device (GPU or CPU)
        self.device = device

        # Initialize multi-label and multi-class classifiers
        self.tags_classifier = MultiLabelClassificationHead(num_labels=self.tags_num_classes).to(self.device)
        self.ratings_classifier = MultiClassClassificationHead(num_labels=self.ratings_num_classes).to(self.device)

        # Restore the head weights saved by a previous run.
        self.tags_classifier.load_state_dict(tag_state_dict)
        self.ratings_classifier.load_state_dict(rating_state_dict)

        # The tag head already applies a sigmoid, so BCELoss (not the logits variant) is used.
        self.BCE = nn.BCELoss().to(self.device)  # Binary Cross Entropy loss for multi-label classification
        self.CE = nn.CrossEntropyLoss().to(self.device)  # Cross Entropy loss for multi-class classification

        self.model = model
        self.lr = config['lr']  # Learning rate

        # Kept under `param_groups` rather than `parameters`: assigning a list to
        # `self.parameters` would shadow nn.Module.parameters() and make
        # `instance.parameters()` fail.
        self.param_groups = [
            {'params': self.model.parameters()},
            {'params': self.tags_classifier.parameters()},
            {'params': self.ratings_classifier.parameters()}
        ]

        # Initialize the optimizer
        self.optimizer = torch.optim.Adam(self.param_groups, lr=self.lr)

    def forward(self, input_ids, attention_mask, tags_labels, ratings_labels):
        """Run the encoder and both heads and compute the combined loss.

        Args:
            input_ids: Token ids passed to the encoder.
            attention_mask: Attention mask passed to the encoder.
            tags_labels: Targets for the tag head (consumed by BCELoss).
            ratings_labels: Targets for the rating head (consumed by CrossEntropyLoss).

        Returns:
            Tuple `(total_loss, tags_output, ratings_output)` where `total_loss`
            is `tags_loss * config['lambda'] + ratings_loss`, `tags_output` holds
            the sigmoid activations and `ratings_output` the rating logits.
        """
        outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
        pooled_output = outputs.pooler_output  # Pooled output from the model

        # Predict tags using the tags classifier
        tags_output = self.tags_classifier(pooled_output)
        tags_loss = self.BCE(tags_output, tags_labels)  # Calculate the loss for tags

        # Predict ratings using the ratings classifier
        ratings_output = self.ratings_classifier(pooled_output)
        ratings_loss = self.CE(ratings_output, ratings_labels)  # Calculate the loss for ratings

        # Weighted sum: the tag loss is scaled by config['lambda'].
        total_loss = tags_loss * config['lambda'] + ratings_loss

        return total_loss, tags_output, ratings_output

In [None]:
def tokenizing(tokenizer, data, max_length):
    """Tokenize a collection of texts into padded, truncated PyTorch tensors.

    Args:
        tokenizer: Callable tokenizer accepting a list of texts plus the
            `padding`, `truncation`, `return_tensors` and `max_length` options.
        data: Object exposing `.values` (the callers pass a pandas Series of texts).
        max_length: Maximum sequence length; longer texts are truncated.

    Returns:
        Whatever the tokenizer returns for the batch.
    """
    texts = list(data.values)
    encode_options = dict(padding=True, truncation=True, return_tensors='pt', max_length=max_length)
    return tokenizer(texts, **encode_options)

In [None]:
def convert_to_tensor(data, dtype):
    """Build a new torch tensor from `data` with the requested `dtype`.

    Args:
        data: Array-like accepted by `torch.tensor` (e.g. a list or NumPy array).
        dtype: Target torch dtype.

    Returns:
        A tensor holding a copy of `data` in the given dtype.
    """
    return torch.tensor(data, dtype=dtype)

In [None]:
from collections import Counter, defaultdict
import os
import shutil
from itertools import chain

class Trainer():
    """Fine-tunes the encoder and both heads, and reports validation metrics every epoch.

    The trainer wraps the encoder in a `classifier`, trains it with gradient
    accumulation and gradient clipping, evaluates tag metrics over a sweep of
    decision thresholds and rating metrics at tolerances of 0, 1 and 2 classes,
    and saves a checkpoint whenever the tag macro-F1 improves.

    Args:
        model: Encoder to fine-tune.
        tokenized_inputs_train: Mapping with 'input_ids' and 'attention_mask' tensors for training.
        tokenized_inputs_test: Same mapping for validation.
        tags_labels_train: Multi-hot tag targets for training.
        tags_labels_test: Multi-hot tag targets for validation.
        ratings_labels_train: Rating class ids for training.
        ratings_labels_test: Rating class ids for validation.
        tag_state_dict: State dict for the tag head.
        rating_state_dict: State dict for the rating head.
    """

    # (metric key, label used in the printed report) for the threshold-swept tag metrics.
    _TAG_METRICS = (
        ('tag_acc', 'tag acc Max Score'),
        ('tag_f1_macro', 'tag valid Max F1 Score(macro) per class'),
        ('tag_f1_micro', 'tag valid Max F1 Score(micro) per class'),
        ('tag_f1_weighted', 'tag valid Max F1 Score(weighted) per class'),
        ('tag_f1_samples', 'tag valid Max F1 Score(samples) per class'),
    )

    # (metric key, label used in the printed report) for the rating metrics.
    _RATING_METRICS = (
        ('rating_acc', 'rating acc Max Score'),
        ('rating_f1_macro', 'rating valid Max F1 Score(macro) per class'),
        ('rating_f1_micro', 'rating valid Max F1 Score(micro) per class'),
        ('rating_f1_weighted', 'rating valid Max F1 Score(weighted) per class'),
    )

    def __init__(self,
                 model,
                 tokenized_inputs_train,
                 tokenized_inputs_test,
                 tags_labels_train,
                 tags_labels_test,
                 ratings_labels_train,
                 ratings_labels_test,
                 tag_state_dict,
                 rating_state_dict
                ):

        # Set the device (GPU or CPU)
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Store the input data and labels
        self.tokenized_inputs_train = tokenized_inputs_train
        self.tokenized_inputs_test = tokenized_inputs_test

        self.tags_labels_train = tags_labels_train
        self.tags_labels_test = tags_labels_test

        self.ratings_labels_train = ratings_labels_train
        self.ratings_labels_test = ratings_labels_test

        # Head sizes are fixed here; they have to match the shapes stored in the
        # tag/rating state dicts that are loaded into the heads.
        self.tags_num_classes = 10
        self.ratings_num_classes = 28

        # Move the model to the specified device
        self.model = model.to(self.device)

        # Define Classifier Instance
        self.classifier_instance = classifier(self.model, self.device, self.tags_num_classes, self.ratings_num_classes, tag_state_dict, rating_state_dict)

        # Retrieve configuration parameters
        self.batch_size = config['batchSize']
        self.num_epochs = config['numEpochs']

        self.accumulation_steps = config['gradient_accumulation_steps']
        self.max_grad_norm = config['max_grad_norm']

        # Tag names, in the column order produced by the fitted binarizer.
        self.tag_classes = tag_label_encoder.classes_

        self.save = config['save']

        # Initialize input data variables
        self.input_ids_train = self.tokenized_inputs_train['input_ids']
        self.attention_mask_train = self.tokenized_inputs_train['attention_mask']

        self.input_ids_test = tokenized_inputs_test['input_ids']
        self.attention_mask_test = tokenized_inputs_test['attention_mask']

    @staticmethod
    def _should_step(step, num_batches, accumulation_steps):
        """Return True when the optimizer should step after the 0-based batch `step`.

        A step is taken after every `accumulation_steps` batches and after the
        last batch of the epoch, so no accumulated gradient is thrown away.
        """
        return (step + 1) % accumulation_steps == 0 or step + 1 == num_batches

    @staticmethod
    def _checkpoint_epoch(name):
        """Return the epoch number at the end of a '<date>_<epoch>' name, or -1 if there is none."""
        suffix = name.rsplit('_', 1)[-1]
        return int(suffix) if suffix.isdigit() else -1

    @staticmethod
    def _tag_scores(y_true, y_pred):
        """Compute accuracy and the four F1 averages for multi-hot tag predictions."""
        scores = {'tag_acc': accuracy_score(y_true, y_pred)}
        for average in ('macro', 'micro', 'weighted', 'samples'):
            scores[f'tag_f1_{average}'] = f1_score(y_true, y_pred, average=average, zero_division=0)
        return scores

    @staticmethod
    def _rating_scores(y_true, y_pred):
        """Compute accuracy and the macro/micro/weighted F1 for rating predictions."""
        scores = {'rating_acc': accuracy_score(y_true, y_pred)}
        for average in ('macro', 'micro', 'weighted'):
            scores[f'rating_f1_{average}'] = f1_score(y_true, y_pred, average=average, zero_division=0)
        return scores

    def train(self):
        """Run the full training loop, printing a validation report after every epoch.

        Each epoch trains over the shuffled training set with gradient
        accumulation, then evaluates on the validation set. A checkpoint is
        written (when `self.save` is set) each time the best tag macro-F1 improves.
        """
        optimizer = self.classifier_instance.optimizer
        classifier_instance = self.classifier_instance
        model = self.model
        device = self.device
        batch_size = self.batch_size
        num_epochs = self.num_epochs

        # A non-positive setting means "no accumulation"; it also avoids dividing by zero.
        accumulation_steps = max(1, int(self.accumulation_steps))

        # Create DataLoaders for batching the data
        train_dataset = TensorDataset(self.input_ids_train, self.attention_mask_train, self.tags_labels_train, self.ratings_labels_train)
        train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=8, pin_memory=True)

        valid_dataset = TensorDataset(self.input_ids_test, self.attention_mask_test, self.tags_labels_test, self.ratings_labels_test)
        valid_dataloader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False, num_workers=8, pin_memory=True)

        num_train_batches = len(train_dataloader)
        num_valid_batches = len(valid_dataloader)

        # Decision thresholds swept when turning tag probabilities into predictions.
        thresholds = [0.001] + [i * 0.01 for i in range(1, 101)]

        # Best score seen so far for every tracked metric, and the (0-based) epoch it occurred in.
        tracked = (
            [key for key, _ in self._TAG_METRICS]
            + ['tag_roc_auc']
            + [key for key, _ in self._RATING_METRICS]
            + ['total_f1_macro']
        )
        best_score = dict.fromkeys(tracked, 0)
        best_epoch = dict.fromkeys(tracked, 0)

        # Epochs since the last tag macro-F1 improvement (only used by the disabled early stopping).
        count = 0

        for epoch in range(num_epochs):
            # set early stopping
            #if count > 8:
            #    break
            count += 1

            # ---- Training ----
            train_loss = 0.0
            classifier_instance.train()
            optimizer.zero_grad()
            for step, batch in enumerate(tqdm(train_dataloader)):
                input_ids, attention_mask, tags_labels, ratings_labels = (t.to(device) for t in batch)

                loss, _, _ = classifier_instance(input_ids, attention_mask, tags_labels, ratings_labels)
                train_loss += loss.item()  # unscaled, so it is comparable with the validation loss

                # Scale so the accumulated gradient is the average over the accumulated batches.
                (loss / accumulation_steps).backward()

                # Step on the batch index (not the epoch) so every epoch actually updates the weights.
                if self._should_step(step, num_train_batches, accumulation_steps):
                    if self.max_grad_norm > 0:
                        torch.nn.utils.clip_grad_norm_(chain(
                            model.parameters(),
                            classifier_instance.tags_classifier.parameters(),
                            classifier_instance.ratings_classifier.parameters()
                        ), self.max_grad_norm)
                    optimizer.step()
                    optimizer.zero_grad()

            # ---- Validation ----
            valid_loss = 0.0
            tag_prob_batches = []
            tag_label_batches = []
            rating_pred_batches = []
            rating_label_batches = []

            classifier_instance.eval()
            with torch.no_grad():
                for batch in tqdm(valid_dataloader):
                    input_ids, attention_mask, tags_labels, ratings_labels = (t.to(device) for t in batch)

                    loss, tags_output, ratings_output = classifier_instance(input_ids, attention_mask, tags_labels, ratings_labels)
                    valid_loss += loss.item()

                    tag_prob_batches.append(tags_output.detach().cpu())
                    tag_label_batches.append(tags_labels.detach().cpu())
                    rating_pred_batches.append(torch.argmax(ratings_output, dim=1).detach().cpu())
                    rating_label_batches.append(ratings_labels.detach().cpu())

            # Each loss is a per-batch mean, so average over the number of batches.
            train_loss /= max(1, num_train_batches)
            valid_loss /= max(1, num_valid_batches)

            # Print the loss for monitoring
            print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.4f}, Valid Loss: {valid_loss:.4f}")

            tag_probs = torch.cat(tag_prob_batches)
            tags_pred_proba = tag_probs.numpy()
            tag_true = (torch.cat(tag_label_batches) != 0).int().numpy()

            rating_true = torch.cat(rating_label_batches).tolist()
            rating_raw = torch.cat(rating_pred_batches).tolist()

            # Tolerance-k ratings: a prediction within k classes of the truth counts as correct.
            rating_k = {
                k: [true if abs(true - pred) <= k else pred for true, pred in zip(rating_true, rating_raw)]
                for k in (0, 1, 2)
            }

            rat_t = Counter(rating_true)
            rat_p = Counter(rating_k[1])

            epoch_score = dict.fromkeys(tracked, 0)
            epoch_score['tag_roc_auc'] = roc_auc_score(tag_true, tags_pred_proba)

            # For every tag metric keep the best value over all thresholds.
            for threshold in thresholds:
                tag_pred = (tag_probs >= threshold).int().numpy()
                for key, value in self._tag_scores(tag_true, tag_pred).items():
                    epoch_score[key] = max(epoch_score[key], value)

            #tag
            for key, label in self._TAG_METRICS:
                print(f"{label} in this epoch:", epoch_score[key])
            print()
            print("tag valid Max roc_auc_score avg in this epoch:", epoch_score['tag_roc_auc'])
            for class_idx in range(self.tags_num_classes):
                score = roc_auc_score(tag_true[:, class_idx], tags_pred_proba[:, class_idx])
                print(f"{self.tag_classes[class_idx]} : {score}")
            print()

            # rating: report every tolerance, but only tolerance 1 feeds the tracked scores.
            for k in (0, 1, 2):
                scores = self._rating_scores(rating_true, rating_k[k])

                if k == 1:
                    for key, value in scores.items():
                        epoch_score[key] = max(epoch_score[key], value)

                for key, label in self._RATING_METRICS:
                    print(f"{label} in this epoch at {k}:", scores[key])
                print()

            epoch_score['total_f1_macro'] = (epoch_score['tag_f1_macro'] + epoch_score['rating_f1_macro']) / 2

            #rating
            for key, label in self._RATING_METRICS:
                print(f"{label} in this epoch:", epoch_score[key])
            print()
            print('rating_true : ', sorted(rat_t.items(), key=lambda x: x[0]))
            print('rating_pred : ', sorted(rat_p.items(), key=lambda x: x[0]))
            print()

            print(f"epoch_max_total_f1_score : {epoch_score['total_f1_macro']}")
            print()

            # Best values so far (before this epoch is taken into account).
            #tag
            for key, label in self._TAG_METRICS:
                print(f"{label}: {best_score[key]} at {best_epoch[key]}epochs")
            print(f"tag valid Max roc_auc_score: {best_score['tag_roc_auc']} at {best_epoch['tag_roc_auc']}epochs")
            print()

            #rating
            for key, label in self._RATING_METRICS:
                print(f"{label}: {best_score[key]} at {best_epoch[key]}epochs")
            print()

            print(f"prev_max_total_f1_macro_score : {best_score['total_f1_macro']}")

            # Fold this epoch into the running bests.
            for key in tracked:
                if epoch_score[key] > best_score[key]:
                    best_epoch[key] = epoch
                    best_score[key] = epoch_score[key]

                    # The tag macro-F1 is the model-selection metric.
                    if key == 'tag_f1_macro':
                        if self.save:
                            self.save_checkpoint('total', model, epoch)
                        count = 0
                        print('Best Model Saved !')
                        print()

            print('----------------------------------------------------------------------------')
            print()

    def save_checkpoint(self, task, model, epoch, max_checkpoints=5):
        """Save the classifier's state dict and prune old checkpoints of the same day.

        The state dict is written to
        `/content/models/<task>/<date>/<date>_<epoch + 1>/model.pt`. Afterwards
        only the `max_checkpoints` checkpoint directories with the highest
        epoch numbers are kept in that day's folder.

        Args:
            task: Sub-directory name under `/content/models`.
            model: Unused; the state dict of `self.classifier_instance` is saved.
            epoch: 0-based epoch index; the directory name uses `epoch + 1`.
            max_checkpoints: Number of checkpoint directories to keep.
        """
        today = datetime.now().strftime('%Y-%m-%d')
        run_dir = f"/content/models/{task}/{today}"
        checkpoint_path = os.path.join(run_dir, f"{today}_{epoch + 1}")

        # Create the directory if it does not exist yet.
        os.makedirs(checkpoint_path, exist_ok=True)

        # Save the model state_dict
        torch.save(self.classifier_instance.state_dict(), os.path.join(checkpoint_path, "model.pt"))

        # Order by the numeric epoch suffix: plain string sorting would put
        # "..._10" before "..._2" and delete the wrong checkpoint.
        checkpoints = sorted(
            (
                name for name in os.listdir(run_dir)
                if os.path.isdir(os.path.join(run_dir, name)) and self._checkpoint_epoch(name) >= 0
            ),
            key=self._checkpoint_epoch,
        )

        # Delete the oldest checkpoints. A bounded loop cannot spin forever if a delete fails.
        excess = max(0, len(checkpoints) - max(0, max_checkpoints))
        for stale_name in checkpoints[:excess]:
            try:
                shutil.rmtree(os.path.join(run_dir, stale_name))
            except OSError as e:
                print(f"Error while deleting directory: {e}")
In [None]:
# Load the saved state dict of a previously trained classifier.
# NOTE: torch.load unpickles the file, so only load checkpoints from a trusted source.
# map_location='cpu' lets a checkpoint saved on a GPU machine load on a CPU-only
# runtime; the weights are moved to the training device later with the model.
state = torch.load('/content/model.pt', map_location='cpu') # "Please input the path to the saved model."

In [None]:
# Per-component state dicts, filled by splitting the loaded checkpoint by key prefix.
model_state_dict = {}
tag_state_dict = {}
rating_state_dict = {}

In [None]:
# Split the checkpoint into encoder / tag-head / rating-head state dicts.
# Keys are matched on their leading component prefix, and exactly that prefix is
# stripped, so a key that merely contains one of these names elsewhere is not
# mis-assigned or mis-sliced.
_prefix_targets = (
    ("model.", model_state_dict),
    ("tags_classifier.", tag_state_dict),
    ("ratings_classifier.", rating_state_dict),
)

for k, v in state.items():
    for prefix, target in _prefix_targets:
        if k.startswith(prefix):
            target[k[len(prefix):]] = v
            break

In [None]:
# Encoder instance created in the config cell.
model = config['model']

In [None]:
# Restore the encoder weights extracted from the checkpoint.
model.load_state_dict(model_state_dict)
print('fin')

In [None]:
# Tokenizer instance created in the config cell.
tokenizer = config['tokenizer']

In [None]:
# Tokenize the training and validation descriptions with their configured maximum lengths.
tokenized_inputs_train = tokenizing(tokenizer, X_train, config['trainMaxLength'])
tokenized_inputs_test = tokenizing(tokenizer, X_test, config['testMaxLength'])

In [None]:
# Tag targets are float tensors (consumed by BCELoss); rating targets are long
# tensors (class ids consumed by CrossEntropyLoss).
tags_labels_train = convert_to_tensor(y_tags_train, dtype=torch.float)
tags_labels_test = convert_to_tensor(y_tags_test, dtype=torch.float)
ratings_labels_train = convert_to_tensor(y_ratings_train, dtype=torch.long)
ratings_labels_test = convert_to_tensor(y_ratings_test, dtype=torch.long)

In [None]:
# Build the trainer from the restored encoder, the tokenized inputs, the label
# tensors and the two head state dicts taken from the checkpoint.
trainer = Trainer(model,
                 tokenized_inputs_train,
                 tokenized_inputs_test,
                 tags_labels_train,
                 tags_labels_test,
                 ratings_labels_train,
                 ratings_labels_test,
                 tag_state_dict,
                 rating_state_dict
                 )

In [None]:
# Run fine-tuning for the configured number of epochs.
trainer.train()

In [None]:
# Archive the checkpoint directory into a single zip file.
!zip -r /content/models.zip /content/models