# Data Imports

In [1]:
# let's start with the data and see how it goes
import os
import pandas as pd
HOME = os.getcwd()
train_csv, test_csv = (os.path.join(HOME, 'data', file_name) for file_name in ('train.csv', 'test.csv'))


def _load_reviews(csv_path):
    """Read one split, lower-case its column names and drop the columns that are not needed."""
    frame = pd.read_csv(csv_path)
    frame.columns = [c.lower() for c in frame.columns]
    # 'helpfulness' and 'score' are not used anywhere later in the notebook
    return frame.drop(columns=['helpfulness', 'score'])


df_train = _load_reviews(train_csv)
df_test = _load_reviews(test_csv)

In [2]:
# add a small piece of code to call the pytorch_modular code
from pathlib import Path
import sys

# Walk up from the notebook directory until a folder that contains `src` is found,
# then expose both that folder and `src` itself to the import system.
current = HOME
while 'src' not in os.listdir(current):
    parent = Path(current).parent
    # the parent of the filesystem root is the root itself: stop instead of looping forever
    if str(parent) == str(current):
        raise FileNotFoundError(f"no 'src' directory found in {HOME} or any of its parent directories")
    current = parent

# skip entries that are already present so re-running the cell does not pile up duplicates
for extra_path in (str(current), os.path.join(current, 'src')):
    if extra_path not in sys.path:
        sys.path.append(extra_path)

In [3]:
# Preview the first rows of the training split after the column clean-up.
df_train.head()

Unnamed: 0,title,text,category
0,Golden Valley Natural Buffalo Jerky,The description and photo on this product need...,grocery gourmet food
1,Westing Game,This was a great book!!!! It is well thought t...,toys games
2,Westing Game,"I am a first year teacher, teaching 5th grade....",toys games
3,Westing Game,I got the book at my bookfair at school lookin...,toys games
4,I SPY A is For Jigsaw Puzzle 63pc,Hi! I'm Martine Redman and I created this puzz...,toys games


In [4]:
# Preview the first rows of the test split after the column clean-up.
df_test.head()

Unnamed: 0,id,title,text
0,0,PetSafe Staywell Pet Door with Clear Hard Flap,We've only had it installed about 2 weeks. So ...
1,1,"Kaytee Timothy Cubes, 1-Pound",My bunny had a hard time eating this because t...
2,2,Body Back Buddy,would never in a million years have guessed th...
3,3,SnackMasters California Style Turkey Jerky,"Being the jerky fanatic I am, snackmasters han..."
4,4,Premier Busy Buddy Tug-a-Jug Treat Dispensing ...,Wondered how quick my dog would catch on to th...


In [5]:
import nltk 
from nltk.tokenize import TweetTokenizer
from nltk.corpus import stopwords

def _english_stop_words():
    """Return NLTK's English stop words as a list (raises LookupError if the corpus is missing)."""
    return list(stopwords.words('english'))


try:
    STOP_WORDS = _english_stop_words()
except LookupError:
    # the corpus is not available locally yet: fetch it once, then read it again
    nltk.download('stopwords')
    STOP_WORDS = _english_stop_words()

In [6]:
# preprocessing functions
import re
from typing import List

def to_lower(text: str) -> str:
    """Return ``text`` converted to lower case."""
    lowered = text.lower()
    return lowered

def no_extra_spaces(text: str) -> str:
    """Collapse every run of whitespace (spaces, tabs, newlines) into a single space.

    Leading and trailing whitespace is squeezed to one space as well; it is not stripped.
    """
    # raw string: '\s' inside a plain literal is an invalid escape sequence,
    # which Python reports with a warning and plans to turn into an error
    return re.sub(r'\s+', ' ', text)

def no_extra_chars(text: str) -> str:
    """Replace each run of characters other than ASCII letters, whitespace and ``,!.;:-`` with one space."""
    unwanted_run = re.compile(r'[^a-zA-Z\s,!.;:-]+')
    return unwanted_run.sub(' ', text)

# Sample string mixing letters, digits and punctuation, handy for trying the cleaning helpers by hand.
# NOTE(review): no later cell appears to read this module-level name — confirm before removing it.
text = 'aaa5531--==-||"z2::,.a'

def remove_stop_words(text: str,
                      tokenizer: TweetTokenizer = None) -> str:
    """Lower-case ``text``, tokenize it and drop every token listed in ``STOP_WORDS``.

    Args:
        text: the string to filter.
        tokenizer: tokenizer to use; a fresh ``TweetTokenizer`` is created when omitted.

    Returns:
        The remaining tokens, stripped and joined with single spaces.
    """
    text = to_lower(text)
    tokenizer = TweetTokenizer() if tokenizer is None else tokenizer
    # STOP_WORDS is a list; a set makes each membership check constant-time
    stop_words = set(STOP_WORDS)
    kept_tokens = [t.strip() for t in tokenizer.tokenize(text) if t not in stop_words]
    return " ".join(kept_tokens)

def process(text: str) -> str:
    """Run the full cleaning pipeline on one string.

    Steps, in order: replace unwanted characters, lower-case, remove stop words,
    collapse repeated whitespace.
    """
    cleaned = no_extra_chars(text)
    lowered = to_lower(cleaned)
    without_stop_words = remove_stop_words(lowered)
    return no_extra_spaces(without_stop_words)

import random
# Show the effect of the cleaning pipeline on one reproducibly chosen review.
random.seed(69)
example_position = int(random.random() * len(df_train))
example = df_train['text'][example_position]
print(example)
print(process(example))

# Only the title will be used for classification, so the review body is dropped from both splits.
df_train = df_train.drop(columns=['text'])
df_test = df_test.drop(columns=['text'])

# Remove the training rows that still hold a missing value
# (16 rows have no title according to the original note).
df_train = df_train.dropna()

See the title of this review. Fortunately, I am a packrat, and kept a bunch of hole repair kits from various blow up things that we have gone through over the years. Does not come with a hole repair kit though, just to warn you. Anyway, it is back in black and bouncing our 3 year old all over the place. Indoor only, I would say. Very highly recommended, in spite of a hole within a week of use. Hope that this is the first and last one... probably not.
see title review . fortunately , packrat , kept bunch hole repair kits various blow things gone years . come hole repair kit though , warn . anyway , back black bouncing year old place . indoor , would say . highly recommended , spite hole within week use . hope first last one ... probably .


In [7]:
# Count the missing values left in each column: training split first, then the test split,
# separated by a line of '#' characters.
print(df_train.isna().sum())
print("#" * 100)
print(df_test.isna().sum())

title       0
category    0
dtype: int64
####################################################################################################
id       0
title    5
dtype: int64


In [8]:
import numpy as np

# Category names in label order: the position of a name is its integer label.
_CATEGORIES = (
    'toys games',
    'health personal care',
    'beauty',
    'baby products',
    'pet supplies',
    'grocery gourmet food',
)

# name -> integer label
cat2idx = {name: label for label, name in enumerate(_CATEGORIES)}

# integer label -> name (the inverse mapping)
idx2cat = dict(enumerate(_CATEGORIES))

# making sure the dataframes are ready for training
def df_process_data(row):
    """Clean the ``title`` field of one row; meant for ``DataFrame.apply(..., axis=1)``.

    A title that is a float is treated as missing (presumably the NaN that pandas
    uses for empty cells) and replaced with a randomly chosen category name.
    Any other title is passed through ``process``.

    Returns:
        A copy of ``row`` with the new ``title``; the row passed in is left untouched.
    """
    # pandas does not support functions that mutate the object handed to `apply`
    row = row.copy()
    if isinstance(row['title'], float):
        row['title'] = random.choice(list(cat2idx.keys()))
    else:
        row['title'] = process(row['title'])
    return row

def df_process_labels(row):
    """Turn the ``category`` field of one row into its integer label.

    The category text goes through ``process`` and is then looked up in ``cat2idx``.

    Returns:
        A copy of ``row`` whose ``category`` is the integer label.

    Raises:
        KeyError: if the cleaned category is not a key of ``cat2idx``.
    """
    # pandas does not support functions that mutate the object handed to `apply`
    row = row.copy()
    cleaned_category = process(row['category'])
    # map it to an integer
    row['category'] = cat2idx[cleaned_category]
    return row

# Training split: clean the titles first, then turn the categories into integer labels.
df_train = df_train.apply(df_process_data, axis=1).apply(df_process_labels, axis=1)
# Test split: there are no labels, so only the titles are cleaned.
df_test = df_test.apply(df_process_data, axis=1)

# Embeddings

In [10]:
# in the rest of the code I will be using the d
import torch
from transformers import AutoModel, AutoTokenizer
# Use the GPU when torch can see one, otherwise fall back to the CPU.
NOTEBOOK_DEVICE = 'cpu' if not torch.cuda.is_available() else 'cuda'
# let's keep it simple as for the first iteration
CHECKPOINT = 'distilbert-base-uncased'
# Tokenizer and encoder come from the same checkpoint; the encoder is moved to the chosen device.
TOKENIZER = AutoTokenizer.from_pretrained(CHECKPOINT)
MODEL = AutoModel.from_pretrained(CHECKPOINT)
MODEL = MODEL.to(NOTEBOOK_DEVICE)

In [19]:
import nlpaug.augmenter.word as naw
import random
random.seed(69)
# Two separate augmenter instances built with exactly the same settings.
AUGMENTER1, AUGMENTER2 = (
    naw.ContextualWordEmbsAug(model_path=CHECKPOINT, aug_p=0.1, device=NOTEBOOK_DEVICE)
    for _ in range(2)
)

def augmented_df(df) -> pd.DataFrame:
    """Return two augmented copies of ``df`` stacked on top of each other.

    Each copy keeps the original ``category`` column and replaces ``title`` with the
    output of one of the two module-level augmenters (``AUGMENTER1`` first, then ``AUGMENTER2``).
    """
    augmented_titles = [
        augmenter.augment(df['title'].tolist())
        for augmenter in (AUGMENTER1, AUGMENTER2)
    ]
    copies = [
        pd.DataFrame(data={"title": titles, "category": df['category']})
        for titles in augmented_titles
    ]
    return pd.concat(copies)

In [22]:
# Build the augmented rows from the cleaned training frame (two augmented copies of every row).
df_train_aug = augmented_df(df_train)

In [23]:
# Put the augmented rows in front of the original training rows.
# NOTE(review): df_train is overwritten with a frame that contains itself, so running
# this cell a second time adds the augmented rows once more.
df_train = pd.concat([df_train_aug, df_train])

In [24]:
from sklearn.model_selection import train_test_split
# Hold out 15% of the rows for validation, stratified on the label, with a fixed seed.
# NOTE(review): df_train already includes the augmented copies at this point, so a title
# and its augmented variants can end up on opposite sides of the split; the validation
# scores may therefore be optimistic. Consider splitting before augmenting.
train_data, val_data = train_test_split(
    df_train,
    test_size=0.15,
    stratify=df_train['category'],
    random_state=69,
)

# Train Loaders

In [25]:
from torch.utils.data import DataLoader, Dataset


def collate_function(batch: List[tuple[str, int]], aug_prob: float = None):
    """Turn a list of ``(text, category)`` samples into one batch of embeddings and labels.

    Args:
        batch: the samples of the batch, one ``(text, category)`` tuple each.
        aug_prob: accepted so existing callers keep working; it is not used here.

    Returns:
        ``(embeddings, labels)`` where ``embeddings`` is the ``last_hidden_state`` that
        ``MODEL`` produces for the padded batch and ``labels`` is a float tensor, both on
        ``NOTEBOOK_DEVICE``.
    """
    # split the list of pairs into a list of texts and a list of labels
    texts, labels = [list(column) for column in zip(*batch)]
    # NOTE(review): the labels are emitted as floats, as before; class-index losses usually
    # expect integer labels, so confirm that the training engine casts them.
    y_tensor = torch.FloatTensor(labels).to(device=NOTEBOOK_DEVICE)
    encoded = TOKENIZER(texts, padding=True, return_tensors='pt').to(NOTEBOOK_DEVICE)
    # MODEL is only a feature extractor here (the optimizer is built from the classifier's
    # parameters alone), so run it without autograd: no graph is stored for the encoder and
    # the backward pass stops at the embeddings.
    with torch.no_grad():
        embeddings = MODEL(**encoded).last_hidden_state
    # the encoder output already lives on NOTEBOOK_DEVICE, no extra transfer needed
    return embeddings, y_tensor

# let's create a dataset object really quick:
class LabeledReviewDS(Dataset):
    """Dataset wrapper around a dataframe; each item is the first two cells of one row.

    The first two columns are expected to be the text and its label, in that order.
    """

    def __init__(self, data: pd.DataFrame) -> None:
        super().__init__()
        self.data = data

    def __len__(self):
        n_rows = len(self.data)
        return n_rows

    def __getitem__(self, index) -> tuple[str, int]:
        # positional access: row `index`, first two columns
        first_two_cells = self.data.iloc[index, :2]
        return tuple(first_two_cells)

# Fix torch's random seed before the loaders are created.
torch.manual_seed(69)

train_ds, val_ds = LabeledReviewDS(train_data), LabeledReviewDS(val_data)


def _train_collate(samples):
    """Collate a training batch, forwarding the augmentation probability used for training."""
    return collate_function(samples, aug_prob=0.15)


# Training batches are shuffled and the last incomplete batch is dropped;
# validation batches keep the dataset order and use the plain collate function.
train_dl = DataLoader(
    dataset=train_ds,
    batch_size=64,
    shuffle=True,
    collate_fn=_train_collate,
    drop_last=True,
)
val_dl = DataLoader(
    dataset=val_ds,
    batch_size=64,
    shuffle=False,
    collate_fn=collate_function,
)

# Train A model

In [26]:
from torch import nn
class SeqClassModel(nn.Module):
    """Sequence classifier: bidirectional LSTM, batch normalisation, then one linear head.

    The final hidden states of every LSTM layer and both directions are concatenated
    into one vector per sample, normalised and mapped to the output units.

    Args:
        in_features: size of each input vector in the sequence.
        hidden_size: hidden size of the LSTM (per direction and per layer).
        num_classes: number of classes; with two classes or fewer a single output unit is used.
        num_layers: number of stacked LSTM layers.
        dropout: dropout applied between stacked LSTM layers (ignored for a single layer).
    """
    def __init__(self, 
                in_features: int,
                hidden_size: int, 
                num_classes: int, 
                num_layers: int = 2, 
                dropout: float=0.3, 
                *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.output_units = num_classes if num_classes > 2 else 1
        self.rnn = nn.LSTM(input_size=in_features, 
                           hidden_size=hidden_size, 
                           # dropout only acts between stacked layers; passing a non-zero
                           # value with a single layer just triggers a warning
                           dropout=dropout if num_layers > 1 else 0.0, 
                           num_layers=num_layers,
                           bidirectional=True, # bidirectional RNNs are more powerful
                           batch_first=True # easier manipulation
                           )
        # 2: comes from the fact that the lstm is bidirectional; the final hidden state holds
        # one vector of size hidden_size per layer and per direction
        linear_input_dim = 2 * num_layers * hidden_size 
        self.batch_layer= nn.BatchNorm1d(num_features=linear_input_dim)
        # self.relu_layer = nn.LeakyReLU()
        self.head = nn.Linear(in_features=linear_input_dim, out_features=self.output_units)
        
    def forward(self, x: torch.Tensor):
        """Return the raw scores for a batch.

        Args:
            x: batch-first input, i.e. (batch, sequence, in_features).

        Returns:
            Tensor of shape (batch, output_units).
        """
        # first pass it through the rnn; only the final hidden state is used
        _, (hidden_state, _) = self.rnn(x)
        # hidden_state is (2 * num_layers, batch, hidden_size)
        batch_size = hidden_state.shape[1]
        # move the batch dimension first, then flatten the states of ALL layers and
        # directions into one feature vector per sample
        hidden_state = hidden_state.permute((1, 0, 2)).reshape((batch_size, -1))
        # call the module itself (not .forward) so that registered hooks run
        return self.head(self.batch_layer(hidden_state))


In [31]:
from torch.optim import AdamW
from torch.optim.lr_scheduler import LinearLR
from torchmetrics.classification import MulticlassF1Score, MulticlassAccuracy

# in_features must equal the last dimension of the embeddings returned by the collate
# function — presumably 768 for the chosen checkpoint, TODO confirm.
base_model = SeqClassModel(in_features=768, hidden_size=128, num_classes=6)
optimizer = AdamW(base_model.parameters(), lr=0.01)
scheduler = LinearLR(optimizer, start_factor=1.0, end_factor=0.005, total_iters=100)

accuracy_metric = MulticlassAccuracy(num_classes=6)
f1_metric = MulticlassF1Score(num_classes=6)
metrics = dict(accuracy=accuracy_metric, f1_score=f1_metric)

# Everything the training engine is handed besides the model and the data loaders.
train_configuration = dict(
    optimizer=optimizer,
    scheduler=scheduler,
    min_val_loss=10 ** -4,
    max_epochs=60,
    report_epoch=1,
    device=NOTEBOOK_DEVICE,
    metrics=metrics,
    no_improve_stop=15,
)

In [32]:
import src.pytorch_modular.image_classification.engine_classification as cls
# Logs go to <HOME>/runs and model files to <HOME>/saved_models.
runs_dir = os.path.join(HOME, 'runs')
saved_models_dir = os.path.join(HOME, 'saved_models')
results = cls.train_model(
    base_model,
    train_dl,
    val_dl,
    train_configuration,
    log_dir=runs_dir,
    save_path=saved_models_dir,
)

[INFO] Created SummaryWriter, saving to: /home/ayhem18/DEV/My_Kaggle_Repo/amazon_reviews/runs/experience_27...


  0%|          | 0/60 [00:00<?, ?it/s]

  2%|▏         | 1/60 [02:43<2:41:10, 163.90s/it]

#########################
training loss: 0.624814041431266
train_accuracy: 0.7833075523376465
train_f1_score: 0.7734293341636658
validation loss : 0.5441950296467923
val_accuracy: 0.8055578470230103
val_f1_score: 0.805313229560852
#########################


  3%|▎         | 2/60 [05:28<2:38:38, 164.12s/it]

#########################
training loss: 0.4654958386392767
train_accuracy: 0.8436644673347473
train_f1_score: 0.8346681594848633
validation loss : 0.45899075491631286
val_accuracy: 0.8423064351081848
val_f1_score: 0.8373845815658569
#########################


  5%|▌         | 3/60 [08:11<2:35:32, 163.73s/it]

#########################
training loss: 0.4141918997849149
train_accuracy: 0.8616425395011902
train_f1_score: 0.8530635237693787
validation loss : 0.4326708790786723
val_accuracy: 0.8575281500816345
val_f1_score: 0.8460569977760315
#########################


  7%|▋         | 4/60 [10:55<2:32:48, 163.73s/it]

#########################
training loss: 0.3890629930647379
train_accuracy: 0.8697640895843506
train_f1_score: 0.861338198184967
validation loss : 0.4150596187047079
val_accuracy: 0.8669731616973877
val_f1_score: 0.853295624256134
#########################


  8%|▊         | 5/60 [13:38<2:29:53, 163.51s/it]

#########################
training loss: 0.37274066949986995
train_accuracy: 0.8764039278030396
train_f1_score: 0.8678074479103088
validation loss : 0.3975372667430986
val_accuracy: 0.8630065321922302
val_f1_score: 0.8593823909759521
#########################


 10%|█         | 6/60 [16:21<2:27:02, 163.38s/it]

#########################
training loss: 0.35179631757665114
train_accuracy: 0.882953405380249
train_f1_score: 0.8748288750648499
validation loss : 0.45040311458262994
val_accuracy: 0.8728832006454468
val_f1_score: 0.8379121422767639
#########################


 12%|█▏        | 7/60 [19:04<2:24:15, 163.31s/it]

#########################
training loss: 0.33602279779657357
train_accuracy: 0.8890826106071472
train_f1_score: 0.8805086016654968
validation loss : 0.4116989714456788
val_accuracy: 0.8818569779396057
val_f1_score: 0.8572685718536377
#########################


 13%|█▎        | 8/60 [21:47<2:21:23, 163.15s/it]

#########################
training loss: 0.32734451732935577
train_accuracy: 0.8905859589576721
train_f1_score: 0.8816654682159424
validation loss : 0.3574498515074135
val_accuracy: 0.8923094272613525
val_f1_score: 0.876330554485321
#########################


 15%|█▌        | 9/60 [24:30<2:18:43, 163.20s/it]

#########################
training loss: 0.3154044836871817
train_accuracy: 0.8956797122955322
train_f1_score: 0.8871111273765564
validation loss : 0.3679352754608114
val_accuracy: 0.8814165592193604
val_f1_score: 0.8710653781890869
#########################


 17%|█▋        | 10/60 [27:13<2:15:59, 163.20s/it]

#########################
training loss: 0.3115820486825813
train_accuracy: 0.8970664739608765
train_f1_score: 0.8888110518455505
validation loss : 0.3393267026924072
val_accuracy: 0.8902619481086731
val_f1_score: 0.8802422285079956
#########################


 18%|█▊        | 11/60 [29:56<2:13:15, 163.16s/it]

#########################
training loss: 0.29894667220748
train_accuracy: 0.9011328816413879
train_f1_score: 0.893094003200531
validation loss : 0.36778156015467134
val_accuracy: 0.8867027163505554
val_f1_score: 0.8715475797653198
#########################


 20%|██        | 12/60 [32:40<2:10:36, 163.26s/it]

#########################
training loss: 0.2897338869806662
train_accuracy: 0.9053362011909485
train_f1_score: 0.8965623378753662
validation loss : 0.35452977079131925
val_accuracy: 0.880494236946106
val_f1_score: 0.8736609220504761
#########################


 22%|██▏       | 13/60 [35:23<2:07:50, 163.20s/it]

#########################
training loss: 0.28197246479374466
train_accuracy: 0.9077281355857849
train_f1_score: 0.899604320526123
validation loss : 0.37937393134578745
val_accuracy: 0.8851537108421326
val_f1_score: 0.8774651288986206
#########################


 23%|██▎       | 14/60 [38:06<2:05:05, 163.17s/it]

#########################
training loss: 0.27518541195148655
train_accuracy: 0.908929169178009
train_f1_score: 0.9007154703140259
validation loss : 0.32691813537732084
val_accuracy: 0.8901875019073486
val_f1_score: 0.8834104537963867
#########################


 25%|██▌       | 15/60 [40:49<2:02:19, 163.09s/it]

#########################
training loss: 0.26782796150032484
train_accuracy: 0.911885142326355
train_f1_score: 0.9039446115493774
validation loss : 0.3561975881457329
val_accuracy: 0.8799182772636414
val_f1_score: 0.8744219541549683
#########################


 27%|██▋       | 16/60 [43:33<1:59:44, 163.29s/it]

#########################
training loss: 0.2634116089543268
train_accuracy: 0.9137076139450073
train_f1_score: 0.9060928225517273
validation loss : 0.34185320326516816
val_accuracy: 0.8951003551483154
val_f1_score: 0.8844327926635742
#########################


 28%|██▊       | 17/60 [46:16<1:57:00, 163.28s/it]

#########################
training loss: 0.2630892631370394
train_accuracy: 0.9119928479194641
train_f1_score: 0.9040103554725647
validation loss : 0.3364093080691412
val_accuracy: 0.8919230103492737
val_f1_score: 0.8822560906410217
#########################


 30%|███       | 18/60 [49:00<1:54:20, 163.34s/it]

#########################
training loss: 0.25470490519329203
train_accuracy: 0.9170425534248352
train_f1_score: 0.9093701839447021
validation loss : 0.3494340869340491
val_accuracy: 0.8924699425697327
val_f1_score: 0.8840936422348022
#########################


 32%|███▏      | 19/60 [51:43<1:51:32, 163.23s/it]

#########################
training loss: 0.2527380825814891
train_accuracy: 0.9170505404472351
train_f1_score: 0.9097360968589783
validation loss : 0.3389211665601172
val_accuracy: 0.8815774321556091
val_f1_score: 0.8811331391334534
#########################


 33%|███▎      | 20/60 [54:26<1:48:52, 163.31s/it]

#########################
training loss: 0.23890126145538937
train_accuracy: 0.9209684133529663
train_f1_score: 0.9131979942321777
validation loss : 0.3297909198606268
val_accuracy: 0.8871961832046509
val_f1_score: 0.882787823677063
#########################


 35%|███▌      | 21/60 [57:11<1:46:23, 163.67s/it]

#########################
training loss: 0.23500788496656383
train_accuracy: 0.9220826029777527
train_f1_score: 0.9147211313247681
validation loss : 0.32630078367730403
val_accuracy: 0.8928298950195312
val_f1_score: 0.8851272463798523
#########################


 37%|███▋      | 22/60 [59:54<1:43:36, 163.59s/it]

#########################
training loss: 0.23198215763054775
train_accuracy: 0.9240731000900269
train_f1_score: 0.91620934009552
validation loss : 0.32762577683261945
val_accuracy: 0.8943313956260681
val_f1_score: 0.8843120336532593
#########################


 38%|███▊      | 23/60 [1:02:37<1:40:49, 163.50s/it]

#########################
training loss: 0.2230685958703097
train_accuracy: 0.9256229400634766
train_f1_score: 0.9184365272521973
validation loss : 0.3496795007100342
val_accuracy: 0.8909634351730347
val_f1_score: 0.8842213153839111
#########################


 40%|████      | 24/60 [1:05:21<1:38:07, 163.54s/it]

#########################
training loss: 0.2207583234729029
train_accuracy: 0.9278699159622192
train_f1_score: 0.9198514819145203
validation loss : 0.323334839934787
val_accuracy: 0.8980064988136292
val_f1_score: 0.8897956609725952
#########################


 42%|████▏     | 25/60 [1:08:04<1:35:21, 163.47s/it]

#########################
training loss: 0.2145512694561923
train_accuracy: 0.9296973347663879
train_f1_score: 0.9227617383003235
validation loss : 0.31077250058875017
val_accuracy: 0.8986194133758545
val_f1_score: 0.8912683725357056
#########################


 43%|████▎     | 26/60 [1:10:47<1:32:32, 163.32s/it]

#########################
training loss: 0.21108104565928584
train_accuracy: 0.9309092164039612
train_f1_score: 0.9241523742675781
validation loss : 0.31159783979045586
val_accuracy: 0.9081543684005737
val_f1_score: 0.8949733972549438
#########################


 45%|████▌     | 27/60 [1:13:31<1:29:53, 163.43s/it]

#########################
training loss: 0.20838112483936438
train_accuracy: 0.9312685132026672
train_f1_score: 0.9238835573196411
validation loss : 0.317154383986977
val_accuracy: 0.9064380526542664
val_f1_score: 0.8945971131324768
#########################


 47%|████▋     | 28/60 [1:16:14<1:27:04, 163.28s/it]

#########################
training loss: 0.2033417454130868
train_accuracy: 0.9323380589485168
train_f1_score: 0.9259499311447144
validation loss : 0.332357747784109
val_accuracy: 0.9031924605369568
val_f1_score: 0.8893856406211853
#########################


 48%|████▊     | 29/60 [1:18:57<1:24:18, 163.17s/it]

#########################
training loss: 0.20197653648024338
train_accuracy: 0.9319815635681152
train_f1_score: 0.9252028465270996
validation loss : 0.3073870747330341
val_accuracy: 0.9008771777153015
val_f1_score: 0.8930427432060242
#########################


 50%|█████     | 30/60 [1:21:40<1:21:32, 163.08s/it]

#########################
training loss: 0.19829467611652796
train_accuracy: 0.9339228868484497
train_f1_score: 0.9268865585327148
validation loss : 0.31479358496078363
val_accuracy: 0.9077026844024658
val_f1_score: 0.8911954164505005
#########################


 52%|█████▏    | 31/60 [1:24:23<1:18:49, 163.08s/it]

#########################
training loss: 0.1917570425906059
train_accuracy: 0.9372875690460205
train_f1_score: 0.9303452372550964
validation loss : 0.3087385778467283
val_accuracy: 0.9078048467636108
val_f1_score: 0.8964628577232361
#########################


 53%|█████▎    | 32/60 [1:27:06<1:16:05, 163.07s/it]

#########################
training loss: 0.1932359884229383
train_accuracy: 0.936916172504425
train_f1_score: 0.9304003715515137
validation loss : 0.28468405069611596
val_accuracy: 0.9198958873748779
val_f1_score: 0.9040467739105225
#########################


 55%|█████▌    | 33/60 [1:29:49<1:13:23, 163.08s/it]

#########################
training loss: 0.18201281745700026
train_accuracy: 0.9405566453933716
train_f1_score: 0.9335569739341736
validation loss : 0.2936022647834839
val_accuracy: 0.9144868850708008
val_f1_score: 0.8997933268547058
#########################


 57%|█████▋    | 34/60 [1:32:31<1:10:35, 162.90s/it]

#########################
training loss: 0.18102232542073696
train_accuracy: 0.940389096736908
train_f1_score: 0.9331119060516357
validation loss : 0.30962017503507594
val_accuracy: 0.9140994548797607
val_f1_score: 0.8936704993247986
#########################


 58%|█████▊    | 35/60 [1:35:14<1:07:54, 162.98s/it]

#########################
training loss: 0.17621359654233953
train_accuracy: 0.9415411949157715
train_f1_score: 0.9346655607223511
validation loss : 0.2962948909253939
val_accuracy: 0.9211128950119019
val_f1_score: 0.9023481011390686
#########################


 60%|██████    | 36/60 [1:37:57<1:05:12, 163.01s/it]

#########################
training loss: 0.17204395427888322
train_accuracy: 0.9444724321365356
train_f1_score: 0.9375982284545898
validation loss : 0.2750424177035795
val_accuracy: 0.9055168032646179
val_f1_score: 0.9035176634788513
#########################


 62%|██████▏   | 37/60 [1:40:40<1:02:28, 162.96s/it]

#########################
training loss: 0.16588986710611933
train_accuracy: 0.9463545680046082
train_f1_score: 0.9400557279586792
validation loss : 0.2940605257023522
val_accuracy: 0.9078581929206848
val_f1_score: 0.9029098153114319
#########################


 63%|██████▎   | 38/60 [1:43:23<59:44, 162.93s/it]  

#########################
training loss: 0.16418292142137444
train_accuracy: 0.9454382061958313
train_f1_score: 0.9392470717430115
validation loss : 0.2886722904781923
val_accuracy: 0.9076282382011414
val_f1_score: 0.9033184051513672
#########################


 65%|██████▌   | 39/60 [1:46:06<56:59, 162.83s/it]

#########################
training loss: 0.16184458951071157
train_accuracy: 0.9474246501922607
train_f1_score: 0.9407219886779785
validation loss : 0.27733837725951316
val_accuracy: 0.9070038795471191
val_f1_score: 0.9055137634277344
#########################


 67%|██████▋   | 40/60 [1:48:49<54:16, 162.81s/it]

#########################
training loss: 0.15814255314545292
train_accuracy: 0.9481683969497681
train_f1_score: 0.9418627619743347
validation loss : 0.2710557428736847
val_accuracy: 0.9178174138069153
val_f1_score: 0.9077887535095215
#########################


 68%|██████▊   | 41/60 [1:51:31<51:33, 162.82s/it]

#########################
training loss: 0.15488939175799685
train_accuracy: 0.9489520788192749
train_f1_score: 0.9432083368301392
validation loss : 0.2960408933672076
val_accuracy: 0.9145895838737488
val_f1_score: 0.9037578105926514
#########################


 70%|███████   | 42/60 [1:54:14<48:52, 162.90s/it]

#########################
training loss: 0.15323804696987933
train_accuracy: 0.9499512314796448
train_f1_score: 0.9437553286552429
validation loss : 0.27117619629130296
val_accuracy: 0.9197510480880737
val_f1_score: 0.9113265872001648
#########################


 72%|███████▏  | 43/60 [1:56:57<46:08, 162.84s/it]

#########################
training loss: 0.15054632118761035
train_accuracy: 0.9515513777732849
train_f1_score: 0.9451776146888733
validation loss : 0.27245274382967055
val_accuracy: 0.9122451543807983
val_f1_score: 0.9063056707382202
#########################


 73%|███████▎  | 44/60 [1:59:40<43:27, 162.97s/it]

#########################
training loss: 0.1469227932393551
train_accuracy: 0.9518728852272034
train_f1_score: 0.9462850093841553
validation loss : 0.2762318446254688
val_accuracy: 0.9173176884651184
val_f1_score: 0.9117028713226318
#########################


 75%|███████▌  | 45/60 [2:02:23<40:44, 162.97s/it]

#########################
training loss: 0.14069874037515515
train_accuracy: 0.9545965790748596
train_f1_score: 0.9484971165657043
validation loss : 0.3060370916170431
val_accuracy: 0.9074616432189941
val_f1_score: 0.8915858864784241
#########################


 77%|███████▋  | 46/60 [2:05:06<38:01, 162.94s/it]

#########################
training loss: 0.1356954123007516
train_accuracy: 0.9562616944313049
train_f1_score: 0.9505376815795898
validation loss : 0.29638452184591313
val_accuracy: 0.9090223908424377
val_f1_score: 0.9064628481864929
#########################


 78%|███████▊  | 47/60 [2:07:49<35:17, 162.90s/it]

#########################
training loss: 0.13516864465807996
train_accuracy: 0.956018328666687
train_f1_score: 0.9502392411231995
validation loss : 0.30026422884870085
val_accuracy: 0.9213385581970215
val_f1_score: 0.8990538716316223
#########################


 80%|████████  | 48/60 [2:10:32<32:34, 162.89s/it]

#########################
training loss: 0.13097185940392272
train_accuracy: 0.9569337964057922
train_f1_score: 0.9508448243141174
validation loss : 0.2804393193136293
val_accuracy: 0.9057677984237671
val_f1_score: 0.9044317007064819
#########################


 82%|████████▏ | 49/60 [2:13:14<29:50, 162.75s/it]

#########################
training loss: 0.12747953330416276
train_accuracy: 0.9588859677314758
train_f1_score: 0.9534929990768433
validation loss : 0.3061195734544849
val_accuracy: 0.8996095061302185
val_f1_score: 0.9005774259567261
#########################


 83%|████████▎ | 50/60 [2:15:57<27:07, 162.70s/it]

#########################
training loss: 0.12318630011697292
train_accuracy: 0.9595198631286621
train_f1_score: 0.9539844989776611
validation loss : 0.3029644625353898
val_accuracy: 0.9145586490631104
val_f1_score: 0.9054295420646667
#########################


 85%|████████▌ | 51/60 [2:18:40<24:24, 162.74s/it]

#########################
training loss: 0.11928270759829189
train_accuracy: 0.9617543816566467
train_f1_score: 0.9566172957420349
validation loss : 0.26274917668378944
val_accuracy: 0.924534022808075
val_f1_score: 0.9124342203140259
#########################


 87%|████████▋ | 52/60 [2:21:22<21:40, 162.52s/it]

#########################
training loss: 0.11792728557834274
train_accuracy: 0.9619300365447998
train_f1_score: 0.9567068815231323
validation loss : 0.28329189145501626
val_accuracy: 0.9163578748703003
val_f1_score: 0.9098788499832153
#########################


 88%|████████▊ | 53/60 [2:24:04<18:57, 162.49s/it]

#########################
training loss: 0.11484811850458065
train_accuracy: 0.9627971649169922
train_f1_score: 0.9571723341941833
validation loss : 0.2762122264541421
val_accuracy: 0.9200558066368103
val_f1_score: 0.9138039350509644
#########################


 90%|█████████ | 54/60 [2:26:46<16:14, 162.42s/it]

#########################
training loss: 0.1117161799366051
train_accuracy: 0.963699460029602
train_f1_score: 0.9585846066474915
validation loss : 0.28848923666151705
val_accuracy: 0.9287794232368469
val_f1_score: 0.9127436280250549
#########################


 92%|█████████▏| 55/60 [2:29:30<13:34, 162.81s/it]

#########################
training loss: 0.10757357150559446
train_accuracy: 0.9648270606994629
train_f1_score: 0.9601597785949707
validation loss : 0.26386095967857126
val_accuracy: 0.9280036091804504
val_f1_score: 0.9163966178894043
#########################


 93%|█████████▎| 56/60 [2:32:13<10:51, 162.77s/it]

#########################
training loss: 0.10455742382432164
train_accuracy: 0.9658377170562744
train_f1_score: 0.9611656665802002
validation loss : 0.2779472472925558
val_accuracy: 0.9126416444778442
val_f1_score: 0.9091469645500183
#########################


 95%|█████████▌| 57/60 [2:39:28<12:13, 244.46s/it]

#########################
training loss: 0.10096696730906568
train_accuracy: 0.9676755666732788
train_f1_score: 0.9628610610961914
validation loss : 0.2589647296406593
val_accuracy: 0.9250345230102539
val_f1_score: 0.9185603260993958
#########################


 95%|█████████▌| 57/60 [2:45:42<08:43, 174.42s/it]


RuntimeError: unique_by_key: failed to synchronize: cudaErrorLaunchFailure: unspecified launch failure

# Inference 

In [None]:
# let's make the damn submission
from src.pytorch_modular.pytorch_utilities import load_model
# base_model = SeqClassModel(in_features=768, hidden_size=128, num_classes=6)
# base_model = load_model(base_model=base_model, path=os.path.join(HOME, 'saved_models', '9-17-15-10.pt'))
# let's create a dataset object really quick:
class TestReviewDS(Dataset):
    """Dataset wrapper around the unlabeled frame; each item is the second cell of one row.

    The second column is expected to hold the title text.
    """

    def __init__(self, data: pd.DataFrame) -> None:
        super().__init__()
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index) -> str:
        # a single cell is returned (row `index`, column position 1), not a (text, label) pair
        return self.data.iloc[index, 1]

# we need a different collate function for the unlabeled test data
def test_collate_function(batch):
    """Turn a list of texts into the ``last_hidden_state`` embeddings of ``MODEL``.

    Args:
        batch: the texts of one batch.

    Returns:
        The embeddings of the padded batch, on ``NOTEBOOK_DEVICE``.
    """
    encoded = TOKENIZER(batch, padding=True, return_tensors='pt').to(NOTEBOOK_DEVICE)
    # inference only: no gradients are needed, so do not build an autograd graph
    with torch.no_grad():
        embeddings = MODEL(**encoded).last_hidden_state
    # the encoder output already lives on NOTEBOOK_DEVICE, no extra transfer needed
    return embeddings
    
# Fix torch's random seed before running inference.
torch.manual_seed(69)

test_ds = TestReviewDS(data=df_test)
test_loader = DataLoader(test_ds, batch_size=32, shuffle=False, collate_fn=test_collate_function)
# NOTE(review): base_model is whatever the training cell left in memory (the logged run
# stopped with a CUDA error before finishing) — confirm this is the intended checkpoint.
predicted_labels = cls.inference(base_model, inference_source_data=test_loader, return_tensor='list')
# convert the numerical labels to the string ones
predictions = [idx2cat[label] for label in predicted_labels]

In [None]:
# One row per test sample: its id and the predicted category name.
submission = pd.DataFrame(data={"id": df_test['id'].tolist(), "Category": predictions})
sub_dir = os.path.join(HOME, 'submissions')
# the folder may not exist yet on a fresh checkout; os.listdir would fail without it
os.makedirs(sub_dir, exist_ok=True)
# number the file after the entries already present in the folder
submission.to_csv(os.path.join(sub_dir, f'sub_{len(os.listdir(sub_dir)) + 1}.csv'), index=False)