<a href="https://colab.research.google.com/github/astromad/MyDeepLearningRepo/blob/master/ProductClassification.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!rm -rf Classification_cache
!rm -rf results_PT
!rm -rf logs_PT


In [2]:
!pip install transformers



In [3]:
# Read the Dataset and load in Pandas Dataframe
import pandas as pd
df = pd.read_csv("/content/drive/My Drive/ColabData/Amazon.csv",
                encoding="ISO-8859-1", error_bad_lines=False)

data = df[['category', 'label_title', 'label_description']]
data.dropna(subset=['category'], inplace=True)
print(data.head(3))


                category  ...                                  label_description
0  Headphone Accessories  ...  The pocket-size Koss 3-Band Equalizer delivers...
1     Inkjet Printer Ink  ...  Kodak Black Ink Cartridge 10B is a standard bl...
2  Computers Accessories  ...  1GB - 333MHz DDR333 PC2700 - DDR SDRAM - 184-p...

[3 rows x 3 columns]


In [4]:
# Remove rows if category is null
data.dropna(subset=['category'], inplace=True)

In [5]:
# Convert category description to numerical category ID
encode_dict={}
def encode_label(x):
    if x not in encode_dict.keys():
        encode_dict[x]=len(encode_dict)
    return encode_dict[x]

data['encoded_category'] = data['category'].apply(lambda x: encode_label(x))

In [6]:
# create new dataframe and merge label title and description
newData=pd.DataFrame()
newData['desc']=data['label_title'] +' '+ data['label_description'] 
newData['encoded_category']=data['encoded_category']
#newData['category']=data['category']


In [7]:
# drop any rows with description is null
newData.dropna(subset=['desc'], inplace=True)
nan_rows = newData[newData.isnull().T.any()]
print(nan_rows)

Empty DataFrame
Columns: [desc, encoded_category]
Index: []


In [8]:
newData.loc[20,'desc']

'Energizer Max Alkaline Batteries Energizer Max Batteries  Energizer Max batteries provide long-lasting dependable power for your everyday devices. Energizer MAX is the perfect alkaline battery when you need power that lasts and performance you can count on. A flashlight to lead the way. A smoke detector to alert you. Or simply a radio to play your favorite song. Also ideal for games and toys clocks and much more.                      World s 1st Zero Mercury Alkaline Battery Energizer is a leader in the industry in powering people s lives responsibly Commercially available since 1991                 Up to 7 years shelf life 7 years on AA AAA C and D and up to 5 years shelf life on 9V               Ideal for the devices you use every day from toys to video game controllers to flashlights             Available in multiple cell sizes to power your whole house.          Energizer Max AA Energizer Max AAA Energizer Max C Energizer Max D Energizer Max 9V Energizer Stands for Innovation and 

In [9]:
import nltk
nltk.download('stopwords')
from nltk.corpus import stopwords
stop = stopwords.words('english')

[nltk_data] Downloading package stopwords to /root/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


In [10]:
# Preprocessing on description text data, remove stop words, remove spaces, lowercase
# note: we are not lemmatize as Bert will take care of it
newData['desc']=newData.desc.str.replace("[^\w\s]", "").str.lower()
#newData['desc']=newData.desc.str.replace('\d+', '')
#newData['desc']=newData['desc'].apply(lambda x: [item for item in x.split() if item not in stop])
newData['desc']=newData['desc'].apply(lambda x: ' '.join([item for item in x.split() if item not in stop]))

In [11]:
newData.loc[20,'desc']

'energizer max alkaline batteries energizer max batteries energizer max batteries provide longlasting dependable power everyday devices energizer max perfect alkaline battery need power lasts performance count flashlight lead way smoke detector alert simply radio play favorite song also ideal games toys clocks much world 1st zero mercury alkaline battery energizer leader industry powering people lives responsibly commercially available since 1991 7 years shelf life 7 years aa aaa c 5 years shelf life 9v ideal devices use every day toys video game controllers flashlights available multiple cell sizes power whole house energizer max aa energizer max aaa energizer max c energizer max energizer max 9v energizer stands innovation performance energizer global leader dynamic business providing power solutions full portfolio products including energizer brand battery products energizer max premium alkaline energizer ultimate lithium energizer advanced lithium rechargeable batteries charging sy

In [12]:
# Helper functions to convert category ID to numerical and back
from future.utils import iteritems
label2idx = {t: i for i, t in enumerate(encode_dict)}
idx2label = {v: k for k, v in iteritems(label2idx)}

In [13]:
#print(newData)

In [14]:
# findout number of categories
ClassMax=newData['encoded_category'].max()
print(ClassMax)


705


In [15]:
#data['encoded_category'].describe()

In [16]:
# Create train and test data split
train_size = 0.8
train_dataset=newData.sample(frac=train_size,random_state=200)
test_dataset=newData.drop(train_dataset.index).reset_index(drop=True)
train_dataset = train_dataset.reset_index(drop=True)


print("FULL Dataset: {}".format(newData.shape))
print("TRAIN Dataset: {}".format(train_dataset.shape))
print("TEST Dataset: {}".format(test_dataset.shape))

FULL Dataset: (20701, 2)
TRAIN Dataset: (16561, 2)
TEST Dataset: (4140, 2)


In [17]:
MAX_LEN = 128
LEARNING_RATE = 3e-02

In [18]:
from transformers import (
    AutoConfig,
    AutoTokenizer,
    #BertTokenizer
)
model_args = dict()
model_args['model_name'] = 'bert-base-uncased' 
model_args['cache_dir'] = "Classification_cache/"
model_args['do_basic_tokenize'] = False

config = AutoConfig.from_pretrained(
    model_args['model_name'],
    # num_labels=num_labels,
    # id2label=label_map,
    # label2id={label: i for i, label in enumerate(labels)},
    cache_dir=model_args['cache_dir'],
    return_dict=True,
    num_labels=ClassMax+1
    #num_labels=1
)

tokenizer = AutoTokenizer.from_pretrained(
    model_args['model_name'],
    cache_dir=model_args['cache_dir'],
    is_pretokenized=model_args['do_basic_tokenize'],
    do_basic_tokenize = model_args['do_basic_tokenize']
)

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=433.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=231508.0, style=ProgressStyle(descripti…




In [19]:
import torch
import re
class TorchClassificationDataset(torch.utils.data.Dataset):
    def __init__(self,dataset,max_len):
        self.len = len(dataset)
        self.data = dataset
        self.max_len=max_len
    def __getitem__(self, idx):
        description = str(self.data.desc[idx])
        #description = " ".join(description.split())
        #print(description)
        description = description[:self.max_len]
        #description = re.sub('[^a-zA-Z0-9\n\.]', ' ', description)
        #description = " ".join(description.split())
        #print(description)
        inputs = tokenizer.encode_plus(
            description,
            None,
            add_special_tokens=True,
            max_length=self.max_len,
            #pad_to_max_length=True,
            padding='max_length',
            return_token_type_ids=True,
            truncation=True
        )
        item ={}
        item['input_ids']=torch.tensor(inputs['input_ids'], dtype=torch.long)
        item['token_type_ids']=torch.tensor(inputs['token_type_ids'], dtype=torch.long)
        item['attention_mask']=torch.tensor(inputs['attention_mask'], dtype=torch.long)
        item['labels'] = torch.tensor(self.data.encoded_category[idx], dtype=torch.long)
        return item

    def __len__(self):
        return self.len

In [20]:
def createDataset(framework='pt'):
  if framework=='pt':
    train_ds = TorchClassificationDataset(train_dataset,MAX_LEN)
    test_ds= TorchClassificationDataset(test_dataset,MAX_LEN)
  return train_ds,test_ds

In [21]:
train_ds,test_ds = createDataset('pt')
print('One record of Training dataset')
print(train_dataset.loc[1,'desc'])
print('----')
print(train_ds[1])


One record of Training dataset
vizio xcp200 high performance screen cleaning kit wipe without worry high performance screen cleaning kit vizio safely clean led lcd plasma laptop screen microfiber cloth antidrip nonstreak cleaner safely clean led lcd plasma laptop screen click larger image protect picture cleaning solution ecofriendly cleaning solution alcohol ammonia free effectively removes fingerprints grease repelling dust leave streaks mess picture end cleaning drip unreachable nooks crannies spray microfiber cloth gentle antibacterial microfiber cloth crucial protecting investment removing gunk without scratching delicate surfaces clean vizio kit perfect cleaning solution hdtv monitor laptop cell phone digital camera 3d glasses box bottle screen cleaning solution microfiber cloth
----
{'input_ids': tensor([  101, 26619,  3695,  1060, 21906, 28332,  2152,  2836,  3898,  9344,
         8934, 13387,  2302,  4737,  2152,  2836,  3898,  9344,  8934, 26619,
         3695,  9689,  4550, 

In [22]:
!pip install seqeval



In [23]:
from sklearn.metrics import accuracy_score, precision_recall_fscore_support

def compute_metrics(pred):
    labels = pred.label_ids
    preds = pred.predictions.argmax(-1)
    precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average='micro')
    acc = accuracy_score(labels, preds)
    return {
        'accuracy': acc,
        'f1': f1,
        'precision': precision,
        'recall': recall
    }

In [24]:
# from torch import cuda
# device = 'cuda' if cuda.is_available() else 'cpu'

In [25]:
from transformers import (
    AutoModelForSequenceClassification,
    #BertForSequenceClassification,
    Trainer,
    TrainingArguments
)
model = AutoModelForSequenceClassification.from_pretrained(
    model_args['model_name'],
    config=config,
    cache_dir=model_args['cache_dir'],
)
training_args = TrainingArguments(
    output_dir='./results_PT',          
    num_train_epochs=20,              
    per_device_train_batch_size=32,  
    per_device_eval_batch_size=32,   
    warmup_steps=500,                
    weight_decay=0.01,               
    logging_dir='./logs_PT',            
    logging_steps=3,
    #learning_rate=LEARNING_RATE
)

trainer = Trainer(
    model=model,                         
    args=training_args,                  
    train_dataset=train_ds,        
    eval_dataset=test_ds,
    compute_metrics=compute_metrics,  
)

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=440473133.0, style=ProgressStyle(descri…




Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at

In [26]:
#model.to(device)

In [27]:
# Lets tain the model now
trainer.train()

HBox(children=(FloatProgress(value=0.0, description='Epoch', max=20.0, style=ProgressStyle(description_width='…

HBox(children=(FloatProgress(value=0.0, description='Iteration', max=518.0, style=ProgressStyle(description_wi…

{'loss': 6.559370676676433, 'learning_rate': 3.0000000000000004e-07, 'epoch': 0.005791505791505791, 'total_flos': 8111934554112, 'step': 3}
{'loss': 6.586741129557292, 'learning_rate': 6.000000000000001e-07, 'epoch': 0.011583011583011582, 'total_flos': 16223869108224, 'step': 6}
{'loss': 6.6088409423828125, 'learning_rate': 9e-07, 'epoch': 0.017374517374517374, 'total_flos': 24335803662336, 'step': 9}
{'loss': 6.589232126871745, 'learning_rate': 1.2000000000000002e-06, 'epoch': 0.023166023166023165, 'total_flos': 32447738216448, 'step': 12}
{'loss': 6.579256693522136, 'learning_rate': 1.5e-06, 'epoch': 0.02895752895752896, 'total_flos': 40559672770560, 'step': 15}
{'loss': 6.6060536702473955, 'learning_rate': 1.8e-06, 'epoch': 0.03474903474903475, 'total_flos': 48671607324672, 'step': 18}
{'loss': 6.538370768229167, 'learning_rate': 2.1000000000000002e-06, 'epoch': 0.04054054054054054, 'total_flos': 56783541878784, 'step': 21}
{'loss': 6.5906728108723955, 'learning_rate': 2.40000000000



{'loss': 4.056722005208333, 'learning_rate': 4.99949290060852e-05, 'epoch': 0.9671814671814671, 'total_flos': 1354693070536704, 'step': 501}
{'loss': 4.015787760416667, 'learning_rate': 4.9979716024340775e-05, 'epoch': 0.972972972972973, 'total_flos': 1362805005090816, 'step': 504}
{'loss': 3.9918619791666665, 'learning_rate': 4.996450304259635e-05, 'epoch': 0.9787644787644788, 'total_flos': 1370916939644928, 'step': 507}
{'loss': 3.958984375, 'learning_rate': 4.994929006085193e-05, 'epoch': 0.9845559845559846, 'total_flos': 1379028874199040, 'step': 510}
{'loss': 4.061116536458333, 'learning_rate': 4.993407707910751e-05, 'epoch': 0.9903474903474904, 'total_flos': 1387140808753152, 'step': 513}
{'loss': 4.217692057291667, 'learning_rate': 4.9918864097363085e-05, 'epoch': 0.9961389961389961, 'total_flos': 1395252743307264, 'step': 516}



HBox(children=(FloatProgress(value=0.0, description='Iteration', max=518.0, style=ProgressStyle(description_wi…

{'loss': 3.5071614583333335, 'learning_rate': 4.990365111561866e-05, 'epoch': 1.001930501930502, 'total_flos': 1402097188087296, 'step': 519}
{'loss': 4.153645833333333, 'learning_rate': 4.988843813387424e-05, 'epoch': 1.0077220077220077, 'total_flos': 1410209122641408, 'step': 522}
{'loss': 4.145345052083333, 'learning_rate': 4.987322515212982e-05, 'epoch': 1.0135135135135136, 'total_flos': 1418321057195520, 'step': 525}
{'loss': 4.074381510416667, 'learning_rate': 4.9858012170385396e-05, 'epoch': 1.0193050193050193, 'total_flos': 1426432991749632, 'step': 528}
{'loss': 3.653564453125, 'learning_rate': 4.9842799188640973e-05, 'epoch': 1.0250965250965252, 'total_flos': 1434544926303744, 'step': 531}
{'loss': 3.2071940104166665, 'learning_rate': 4.982758620689655e-05, 'epoch': 1.0308880308880308, 'total_flos': 1442656860857856, 'step': 534}
{'loss': 3.897705078125, 'learning_rate': 4.981237322515213e-05, 'epoch': 1.0366795366795367, 'total_flos': 1450768795411968, 'step': 537}
{'loss': 

HBox(children=(FloatProgress(value=0.0, description='Iteration', max=518.0, style=ProgressStyle(description_wi…

{'loss': 3.1917317708333335, 'learning_rate': 4.7271805273833674e-05, 'epoch': 2.003861003861004, 'total_flos': 2804194376174592, 'step': 1038}
{'loss': 2.5963541666666665, 'learning_rate': 4.725659229208925e-05, 'epoch': 2.0096525096525095, 'total_flos': 2812306310728704, 'step': 1041}
{'loss': 3.03857421875, 'learning_rate': 4.724137931034483e-05, 'epoch': 2.0154440154440154, 'total_flos': 2820418245282816, 'step': 1044}
{'loss': 2.7374674479166665, 'learning_rate': 4.722616632860041e-05, 'epoch': 2.0212355212355213, 'total_flos': 2828530179836928, 'step': 1047}
{'loss': 3.03271484375, 'learning_rate': 4.7210953346855984e-05, 'epoch': 2.027027027027027, 'total_flos': 2836642114391040, 'step': 1050}
{'loss': 2.8147786458333335, 'learning_rate': 4.719574036511156e-05, 'epoch': 2.0328185328185326, 'total_flos': 2844754048945152, 'step': 1053}
{'loss': 2.4290364583333335, 'learning_rate': 4.718052738336714e-05, 'epoch': 2.0386100386100385, 'total_flos': 2852865983499264, 'step': 1056}
{'

HBox(children=(FloatProgress(value=0.0, description='Iteration', max=518.0, style=ProgressStyle(description_wi…

{'loss': 2.39990234375, 'learning_rate': 4.4639959432048685e-05, 'epoch': 3.005791505791506, 'total_flos': 4206291564261888, 'step': 1557}
{'loss': 2.0983072916666665, 'learning_rate': 4.462474645030426e-05, 'epoch': 3.011583011583012, 'total_flos': 4214403498816000, 'step': 1560}
{'loss': 1.9928385416666667, 'learning_rate': 4.460953346855984e-05, 'epoch': 3.0173745173745172, 'total_flos': 4222515433370112, 'step': 1563}
{'loss': 2.3343098958333335, 'learning_rate': 4.459432048681542e-05, 'epoch': 3.023166023166023, 'total_flos': 4230627367924224, 'step': 1566}
{'loss': 2.0139973958333335, 'learning_rate': 4.4579107505070995e-05, 'epoch': 3.028957528957529, 'total_flos': 4238739302478336, 'step': 1569}
{'loss': 2.5616861979166665, 'learning_rate': 4.456389452332657e-05, 'epoch': 3.034749034749035, 'total_flos': 4246851237032448, 'step': 1572}
{'loss': 2.10205078125, 'learning_rate': 4.454868154158216e-05, 'epoch': 3.0405405405405403, 'total_flos': 4254963171586560, 'step': 1575}
{'los

HBox(children=(FloatProgress(value=0.0, description='Iteration', max=518.0, style=ProgressStyle(description_wi…

{'loss': 2.130859375, 'learning_rate': 4.202332657200812e-05, 'epoch': 4.001930501930502, 'total_flos': 5600276817795072, 'step': 2073}
{'loss': 1.8779296875, 'learning_rate': 4.2008113590263695e-05, 'epoch': 4.007722007722008, 'total_flos': 5608388752349184, 'step': 2076}
{'loss': 1.61767578125, 'learning_rate': 4.199290060851927e-05, 'epoch': 4.013513513513513, 'total_flos': 5616500686903296, 'step': 2079}
{'loss': 1.6788736979166667, 'learning_rate': 4.197768762677485e-05, 'epoch': 4.019305019305019, 'total_flos': 5624612621457408, 'step': 2082}
{'loss': 1.5592447916666667, 'learning_rate': 4.196247464503043e-05, 'epoch': 4.025096525096525, 'total_flos': 5632724556011520, 'step': 2085}
{'loss': 1.8313802083333333, 'learning_rate': 4.1947261663286006e-05, 'epoch': 4.030888030888031, 'total_flos': 5640836490565632, 'step': 2088}
{'loss': 1.7628580729166667, 'learning_rate': 4.1932048681541584e-05, 'epoch': 4.036679536679537, 'total_flos': 5648948425119744, 'step': 2091}
{'loss': 1.896

HBox(children=(FloatProgress(value=0.0, description='Iteration', max=518.0, style=ProgressStyle(description_wi…

{'loss': 1.7923177083333333, 'learning_rate': 3.939148073022312e-05, 'epoch': 5.003861003861004, 'total_flos': 7002374005882368, 'step': 2592}
{'loss': 0.888671875, 'learning_rate': 3.93762677484787e-05, 'epoch': 5.0096525096525095, 'total_flos': 7010485940436480, 'step': 2595}
{'loss': 1.1652018229166667, 'learning_rate': 3.9361054766734284e-05, 'epoch': 5.015444015444015, 'total_flos': 7018597874990592, 'step': 2598}
{'loss': 1.3668619791666667, 'learning_rate': 3.934584178498986e-05, 'epoch': 5.021235521235521, 'total_flos': 7026709809544704, 'step': 2601}
{'loss': 1.1373697916666667, 'learning_rate': 3.933062880324544e-05, 'epoch': 5.027027027027027, 'total_flos': 7034821744098816, 'step': 2604}
{'loss': 0.9944661458333334, 'learning_rate': 3.931541582150102e-05, 'epoch': 5.032818532818533, 'total_flos': 7042933678652928, 'step': 2607}
{'loss': 1.3533528645833333, 'learning_rate': 3.9300202839756594e-05, 'epoch': 5.038610038610039, 'total_flos': 7051045613207040, 'step': 2610}
{'lo

HBox(children=(FloatProgress(value=0.0, description='Iteration', max=518.0, style=ProgressStyle(description_wi…

{'loss': 0.958984375, 'learning_rate': 3.675963488843814e-05, 'epoch': 6.005791505791506, 'total_flos': 8404471193969664, 'step': 3111}
{'loss': 1.0035807291666667, 'learning_rate': 3.674442190669372e-05, 'epoch': 6.011583011583012, 'total_flos': 8412583128523776, 'step': 3114}
{'loss': 1.009765625, 'learning_rate': 3.6729208924949295e-05, 'epoch': 6.017374517374518, 'total_flos': 8420695063077888, 'step': 3117}
{'loss': 0.9143880208333334, 'learning_rate': 3.671399594320487e-05, 'epoch': 6.023166023166024, 'total_flos': 8428806997632000, 'step': 3120}
{'loss': 1.1064453125, 'learning_rate': 3.669878296146045e-05, 'epoch': 6.028957528957529, 'total_flos': 8436918932186112, 'step': 3123}
{'loss': 0.8512369791666666, 'learning_rate': 3.668356997971603e-05, 'epoch': 6.0347490347490345, 'total_flos': 8445030866740224, 'step': 3126}
{'loss': 1.1110026041666667, 'learning_rate': 3.6668356997971605e-05, 'epoch': 6.04054054054054, 'total_flos': 8453142801294336, 'step': 3129}
{'loss': 1.201497

HBox(children=(FloatProgress(value=0.0, description='Iteration', max=518.0, style=ProgressStyle(description_wi…

{'loss': 1.0100911458333333, 'learning_rate': 3.4143002028397566e-05, 'epoch': 7.001930501930502, 'total_flos': 9798456447502848, 'step': 3627}
{'loss': 1.046875, 'learning_rate': 3.4127789046653144e-05, 'epoch': 7.007722007722008, 'total_flos': 9806568382056960, 'step': 3630}
{'loss': 1.0716145833333333, 'learning_rate': 3.411257606490872e-05, 'epoch': 7.013513513513513, 'total_flos': 9814680316611072, 'step': 3633}
{'loss': 0.8118489583333334, 'learning_rate': 3.40973630831643e-05, 'epoch': 7.019305019305019, 'total_flos': 9822792251165184, 'step': 3636}
{'loss': 0.9430338541666666, 'learning_rate': 3.4082150101419876e-05, 'epoch': 7.025096525096525, 'total_flos': 9830904185719296, 'step': 3639}
{'loss': 0.9156901041666666, 'learning_rate': 3.4066937119675454e-05, 'epoch': 7.030888030888031, 'total_flos': 9839016120273408, 'step': 3642}
{'loss': 1.0755208333333333, 'learning_rate': 3.405172413793103e-05, 'epoch': 7.036679536679537, 'total_flos': 9847128054827520, 'step': 3645}
{'loss

HBox(children=(FloatProgress(value=0.0, description='Iteration', max=518.0, style=ProgressStyle(description_wi…

{'loss': 0.8606770833333334, 'learning_rate': 3.151115618661258e-05, 'epoch': 8.003861003861005, 'total_flos': 11200553635590144, 'step': 4146}
{'loss': 0.6064453125, 'learning_rate': 3.1495943204868154e-05, 'epoch': 8.00965250965251, 'total_flos': 11208665570144256, 'step': 4149}
{'loss': 0.6184895833333334, 'learning_rate': 3.148073022312373e-05, 'epoch': 8.015444015444016, 'total_flos': 11216777504698368, 'step': 4152}
{'loss': 0.7327473958333334, 'learning_rate': 3.146551724137931e-05, 'epoch': 8.021235521235521, 'total_flos': 11224889439252480, 'step': 4155}
{'loss': 0.8616536458333334, 'learning_rate': 3.145030425963489e-05, 'epoch': 8.027027027027026, 'total_flos': 11233001373806592, 'step': 4158}
{'loss': 0.7526041666666666, 'learning_rate': 3.1435091277890465e-05, 'epoch': 8.032818532818533, 'total_flos': 11241113308360704, 'step': 4161}
{'loss': 0.6910807291666666, 'learning_rate': 3.141987829614604e-05, 'epoch': 8.038610038610038, 'total_flos': 11249225242914816, 'step': 416

HBox(children=(FloatProgress(value=0.0, description='Iteration', max=518.0, style=ProgressStyle(description_wi…

{'loss': 0.6110026041666666, 'learning_rate': 2.8879310344827588e-05, 'epoch': 9.005791505791505, 'total_flos': 12602650823677440, 'step': 4665}
{'loss': 0.384765625, 'learning_rate': 2.8864097363083165e-05, 'epoch': 9.011583011583012, 'total_flos': 12610762758231552, 'step': 4668}
{'loss': 0.3837890625, 'learning_rate': 2.8848884381338743e-05, 'epoch': 9.017374517374517, 'total_flos': 12618874692785664, 'step': 4671}
{'loss': 0.6923828125, 'learning_rate': 2.883367139959432e-05, 'epoch': 9.023166023166024, 'total_flos': 12626986627339776, 'step': 4674}
{'loss': 0.5807291666666666, 'learning_rate': 2.8818458417849898e-05, 'epoch': 9.028957528957529, 'total_flos': 12635098561893888, 'step': 4677}
{'loss': 0.43359375, 'learning_rate': 2.880324543610548e-05, 'epoch': 9.034749034749035, 'total_flos': 12643210496448000, 'step': 4680}
{'loss': 0.7965494791666666, 'learning_rate': 2.8788032454361057e-05, 'epoch': 9.04054054054054, 'total_flos': 12651322431002112, 'step': 4683}
{'loss': 0.4163

HBox(children=(FloatProgress(value=0.0, description='Iteration', max=518.0, style=ProgressStyle(description_wi…

{'loss': 0.7861328125, 'learning_rate': 2.6262677484787017e-05, 'epoch': 10.001930501930502, 'total_flos': 13996636077210624, 'step': 5181}
{'loss': 0.5341796875, 'learning_rate': 2.6247464503042595e-05, 'epoch': 10.007722007722007, 'total_flos': 14004748011764736, 'step': 5184}
{'loss': 0.51171875, 'learning_rate': 2.6232251521298173e-05, 'epoch': 10.013513513513514, 'total_flos': 14012859946318848, 'step': 5187}
{'loss': 0.5328776041666666, 'learning_rate': 2.6217038539553757e-05, 'epoch': 10.019305019305019, 'total_flos': 14020971880872960, 'step': 5190}
{'loss': 0.4895833333333333, 'learning_rate': 2.6201825557809335e-05, 'epoch': 10.025096525096526, 'total_flos': 14029083815427072, 'step': 5193}
{'loss': 0.4147135416666667, 'learning_rate': 2.6186612576064912e-05, 'epoch': 10.03088803088803, 'total_flos': 14037195749981184, 'step': 5196}
{'loss': 0.3935546875, 'learning_rate': 2.617139959432049e-05, 'epoch': 10.036679536679536, 'total_flos': 14045307684535296, 'step': 5199}
{'loss

HBox(children=(FloatProgress(value=0.0, description='Iteration', max=518.0, style=ProgressStyle(description_wi…

{'loss': 0.3199869791666667, 'learning_rate': 2.363083164300203e-05, 'epoch': 11.003861003861005, 'total_flos': 15398733265297920, 'step': 5700}
{'loss': 0.2819010416666667, 'learning_rate': 2.3615618661257606e-05, 'epoch': 11.00965250965251, 'total_flos': 15406845199852032, 'step': 5703}
{'loss': 0.3678385416666667, 'learning_rate': 2.3600405679513184e-05, 'epoch': 11.015444015444016, 'total_flos': 15414957134406144, 'step': 5706}
{'loss': 0.3229166666666667, 'learning_rate': 2.358519269776876e-05, 'epoch': 11.021235521235521, 'total_flos': 15423069068960256, 'step': 5709}
{'loss': 0.2197265625, 'learning_rate': 2.3569979716024342e-05, 'epoch': 11.027027027027026, 'total_flos': 15431181003514368, 'step': 5712}
{'loss': 0.4091796875, 'learning_rate': 2.355476673427992e-05, 'epoch': 11.032818532818533, 'total_flos': 15439292938068480, 'step': 5715}
{'loss': 0.21126302083333334, 'learning_rate': 2.3539553752535497e-05, 'epoch': 11.038610038610038, 'total_flos': 15447404872622592, 'step':

HBox(children=(FloatProgress(value=0.0, description='Iteration', max=518.0, style=ProgressStyle(description_wi…

{'loss': 0.3307291666666667, 'learning_rate': 2.099898580121704e-05, 'epoch': 12.005791505791505, 'total_flos': 16800830453385216, 'step': 6219}
{'loss': 0.3128255208333333, 'learning_rate': 2.0983772819472617e-05, 'epoch': 12.011583011583012, 'total_flos': 16808942387939328, 'step': 6222}
{'loss': 0.3349609375, 'learning_rate': 2.0968559837728198e-05, 'epoch': 12.017374517374517, 'total_flos': 16817054322493440, 'step': 6225}
{'loss': 0.44921875, 'learning_rate': 2.0953346855983775e-05, 'epoch': 12.023166023166024, 'total_flos': 16825166257047552, 'step': 6228}
{'loss': 0.224609375, 'learning_rate': 2.0938133874239353e-05, 'epoch': 12.028957528957529, 'total_flos': 16833278191601664, 'step': 6231}
{'loss': 0.3658854166666667, 'learning_rate': 2.092292089249493e-05, 'epoch': 12.034749034749035, 'total_flos': 16841390126155776, 'step': 6234}
{'loss': 0.556640625, 'learning_rate': 2.0907707910750508e-05, 'epoch': 12.04054054054054, 'total_flos': 16849502060709888, 'step': 6237}
{'loss': 

HBox(children=(FloatProgress(value=0.0, description='Iteration', max=518.0, style=ProgressStyle(description_wi…

{'loss': 0.15234375, 'learning_rate': 1.8382352941176472e-05, 'epoch': 13.001930501930502, 'total_flos': 18194815706918400, 'step': 6735}
{'loss': 0.2591145833333333, 'learning_rate': 1.836713995943205e-05, 'epoch': 13.007722007722007, 'total_flos': 18202927641472512, 'step': 6738}
{'loss': 0.3421223958333333, 'learning_rate': 1.8351926977687628e-05, 'epoch': 13.013513513513514, 'total_flos': 18211039576026624, 'step': 6741}
{'loss': 0.318359375, 'learning_rate': 1.8336713995943205e-05, 'epoch': 13.019305019305019, 'total_flos': 18219151510580736, 'step': 6744}
{'loss': 0.2779947916666667, 'learning_rate': 1.8321501014198783e-05, 'epoch': 13.025096525096526, 'total_flos': 18227263445134848, 'step': 6747}
{'loss': 0.10611979166666667, 'learning_rate': 1.830628803245436e-05, 'epoch': 13.03088803088803, 'total_flos': 18235375379688960, 'step': 6750}
{'loss': 0.2119140625, 'learning_rate': 1.829107505070994e-05, 'epoch': 13.036679536679536, 'total_flos': 18243487314243072, 'step': 6753}
{'

HBox(children=(FloatProgress(value=0.0, description='Iteration', max=518.0, style=ProgressStyle(description_wi…

{'loss': 0.236328125, 'learning_rate': 1.5750507099391483e-05, 'epoch': 14.003861003861005, 'total_flos': 19596912895005696, 'step': 7254}
{'loss': 0.1640625, 'learning_rate': 1.573529411764706e-05, 'epoch': 14.00965250965251, 'total_flos': 19605024829559808, 'step': 7257}
{'loss': 0.3111979166666667, 'learning_rate': 1.5720081135902635e-05, 'epoch': 14.015444015444016, 'total_flos': 19613136764113920, 'step': 7260}
{'loss': 0.16634114583333334, 'learning_rate': 1.5704868154158216e-05, 'epoch': 14.021235521235521, 'total_flos': 19621248698668032, 'step': 7263}
{'loss': 0.24967447916666666, 'learning_rate': 1.5689655172413794e-05, 'epoch': 14.027027027027026, 'total_flos': 19629360633222144, 'step': 7266}
{'loss': 0.3515625, 'learning_rate': 1.567444219066937e-05, 'epoch': 14.032818532818533, 'total_flos': 19637472567776256, 'step': 7269}
{'loss': 0.16015625, 'learning_rate': 1.565922920892495e-05, 'epoch': 14.038610038610038, 'total_flos': 19645584502330368, 'step': 7272}
{'loss': 0.25

HBox(children=(FloatProgress(value=0.0, description='Iteration', max=518.0, style=ProgressStyle(description_wi…

{'loss': 0.197265625, 'learning_rate': 1.311866125760649e-05, 'epoch': 15.005791505791505, 'total_flos': 20999010083092992, 'step': 7773}
{'loss': 0.10188802083333333, 'learning_rate': 1.310344827586207e-05, 'epoch': 15.011583011583012, 'total_flos': 21007122017647104, 'step': 7776}
{'loss': 0.18196614583333334, 'learning_rate': 1.3088235294117648e-05, 'epoch': 15.017374517374517, 'total_flos': 21015233952201216, 'step': 7779}
{'loss': 0.1875, 'learning_rate': 1.3073022312373225e-05, 'epoch': 15.023166023166024, 'total_flos': 21023345886755328, 'step': 7782}
{'loss': 0.13346354166666666, 'learning_rate': 1.3057809330628803e-05, 'epoch': 15.028957528957529, 'total_flos': 21031457821309440, 'step': 7785}
{'loss': 0.10904947916666667, 'learning_rate': 1.304259634888438e-05, 'epoch': 15.034749034749035, 'total_flos': 21039569755863552, 'step': 7788}
{'loss': 0.162109375, 'learning_rate': 1.3027383367139962e-05, 'epoch': 15.04054054054054, 'total_flos': 21047681690417664, 'step': 7791}
{'lo

HBox(children=(FloatProgress(value=0.0, description='Iteration', max=518.0, style=ProgressStyle(description_wi…

{'loss': 0.12825520833333334, 'learning_rate': 1.0502028397565924e-05, 'epoch': 16.001930501930502, 'total_flos': 22392995336626176, 'step': 8289}
{'loss': 0.1298828125, 'learning_rate': 1.0486815415821501e-05, 'epoch': 16.00772200772201, 'total_flos': 22401107271180288, 'step': 8292}
{'loss': 0.212890625, 'learning_rate': 1.0471602434077079e-05, 'epoch': 16.013513513513512, 'total_flos': 22409219205734400, 'step': 8295}
{'loss': 0.123046875, 'learning_rate': 1.0456389452332657e-05, 'epoch': 16.01930501930502, 'total_flos': 22417331140288512, 'step': 8298}
{'loss': 0.16861979166666666, 'learning_rate': 1.0441176470588236e-05, 'epoch': 16.025096525096526, 'total_flos': 22425443074842624, 'step': 8301}
{'loss': 0.115234375, 'learning_rate': 1.0425963488843814e-05, 'epoch': 16.030888030888033, 'total_flos': 22433555009396736, 'step': 8304}
{'loss': 0.16569010416666666, 'learning_rate': 1.0410750507099391e-05, 'epoch': 16.036679536679536, 'total_flos': 22441666943950848, 'step': 8307}
{'lo

HBox(children=(FloatProgress(value=0.0, description='Iteration', max=518.0, style=ProgressStyle(description_wi…

{'loss': 0.12239583333333333, 'learning_rate': 7.870182555780933e-06, 'epoch': 17.003861003861005, 'total_flos': 23795092524713472, 'step': 8808}
{'loss': 0.1982421875, 'learning_rate': 7.85496957403651e-06, 'epoch': 17.00965250965251, 'total_flos': 23803204459267584, 'step': 8811}
{'loss': 0.2766927083333333, 'learning_rate': 7.83975659229209e-06, 'epoch': 17.015444015444015, 'total_flos': 23811316393821696, 'step': 8814}
{'loss': 0.051432291666666664, 'learning_rate': 7.824543610547668e-06, 'epoch': 17.02123552123552, 'total_flos': 23819428328375808, 'step': 8817}
{'loss': 0.146484375, 'learning_rate': 7.809330628803247e-06, 'epoch': 17.027027027027028, 'total_flos': 23827540262929920, 'step': 8820}
{'loss': 0.15983072916666666, 'learning_rate': 7.794117647058825e-06, 'epoch': 17.03281853281853, 'total_flos': 23835652197484032, 'step': 8823}
{'loss': 0.21549479166666666, 'learning_rate': 7.778904665314402e-06, 'epoch': 17.038610038610038, 'total_flos': 23843764132038144, 'step': 8826

HBox(children=(FloatProgress(value=0.0, description='Iteration', max=518.0, style=ProgressStyle(description_wi…

{'loss': 0.15104166666666666, 'learning_rate': 5.238336713995944e-06, 'epoch': 18.005791505791507, 'total_flos': 25197189712800768, 'step': 9327}
{'loss': 0.1318359375, 'learning_rate': 5.2231237322515215e-06, 'epoch': 18.01158301158301, 'total_flos': 25205301647354880, 'step': 9330}
{'loss': 0.11360677083333333, 'learning_rate': 5.207910750507099e-06, 'epoch': 18.017374517374517, 'total_flos': 25213413581908992, 'step': 9333}
{'loss': 0.15299479166666666, 'learning_rate': 5.192697768762678e-06, 'epoch': 18.023166023166024, 'total_flos': 25221525516463104, 'step': 9336}
{'loss': 0.19596354166666666, 'learning_rate': 5.177484787018256e-06, 'epoch': 18.02895752895753, 'total_flos': 25229637451017216, 'step': 9339}
{'loss': 0.15071614583333334, 'learning_rate': 5.1622718052738345e-06, 'epoch': 18.034749034749034, 'total_flos': 25237749385571328, 'step': 9342}
{'loss': 0.12044270833333333, 'learning_rate': 5.147058823529412e-06, 'epoch': 18.04054054054054, 'total_flos': 25245861320125440, 

HBox(children=(FloatProgress(value=0.0, description='Iteration', max=518.0, style=ProgressStyle(description_wi…

{'loss': 0.0703125, 'learning_rate': 2.6217038539553754e-06, 'epoch': 19.001930501930502, 'total_flos': 26591174966333952, 'step': 9843}
{'loss': 0.15006510416666666, 'learning_rate': 2.6064908722109534e-06, 'epoch': 19.00772200772201, 'total_flos': 26599286900888064, 'step': 9846}
{'loss': 0.18717447916666666, 'learning_rate': 2.591277890466532e-06, 'epoch': 19.013513513513512, 'total_flos': 26607398835442176, 'step': 9849}
{'loss': 0.17805989583333334, 'learning_rate': 2.5760649087221095e-06, 'epoch': 19.01930501930502, 'total_flos': 26615510769996288, 'step': 9852}
{'loss': 0.18912760416666666, 'learning_rate': 2.560851926977688e-06, 'epoch': 19.025096525096526, 'total_flos': 26623622704550400, 'step': 9855}
{'loss': 0.15657552083333334, 'learning_rate': 2.5456389452332656e-06, 'epoch': 19.030888030888033, 'total_flos': 26631734639104512, 'step': 9858}
{'loss': 0.3447265625, 'learning_rate': 2.530425963488844e-06, 'epoch': 19.036679536679536, 'total_flos': 26639846573658624, 'step':

TrainOutput(global_step=10360, training_loss=1.0737998461631275)

In [28]:
trainer.evaluate()

HBox(children=(FloatProgress(value=0.0, description='Evaluation', max=130.0, style=ProgressStyle(description_w…


{'eval_loss': 2.2890994422101745, 'eval_accuracy': 0.6442028985507247, 'eval_f1': 0.6442028985507247, 'eval_precision': 0.6442028985507247, 'eval_recall': 0.6442028985507247, 'epoch': 20.0, 'total_flos': 27987864198051840, 'step': 10360}


{'epoch': 20.0,
 'eval_accuracy': 0.6442028985507247,
 'eval_f1': 0.6442028985507247,
 'eval_loss': 2.2890994422101745,
 'eval_precision': 0.6442028985507247,
 'eval_recall': 0.6442028985507247,
 'total_flos': 27987864198051840}

In [29]:
predictions, label_ids, metrics = trainer.predict(test_ds)
for key, value in metrics.items():
    print( key, value)

HBox(children=(FloatProgress(value=0.0, description='Prediction', max=130.0, style=ProgressStyle(description_w…


eval_loss 2.2890994422101745
eval_accuracy 0.6442028985507247
eval_f1 0.6442028985507247
eval_precision 0.6442028985507247
eval_recall 0.6442028985507247


In [30]:
inputs = tokenizer("Da-Lite Stand Master I - Cart for projector Projection Carts - Stand Master I Features The height of both the upper and lower shelves", return_tensors="pt")
print(inputs)
labels = torch.tensor([76]).unsqueeze(0)
print(labels)

{'input_ids': tensor([[  101,  4830, 29624, 22779,  3233,  3040,  1045,  1011, 11122,  2005,
          2622,  2953, 13996, 25568,  1011,  3233,  3040,  1045,  2838,  1996,
          4578,  1997,  2119,  1996,  3356,  1998,  2896, 15475,   102]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1]])}
tensor([[76]])


In [31]:
model.to('cpu')
outputs = model(**inputs, labels=labels)
print(outputs.loss)
pred=outputs.logits.argmax(-1)
print('prediction=',pred,idx2label[(int)(pred.cpu().detach().numpy())])

tensor(0.7203, grad_fn=<NllLossBackward>)
prediction= tensor([76]) Projection Screens
