In [1]:
!pip install transformers

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [2]:
import pandas as pd
from transformers import BertTokenizerFast, BertForTokenClassification
import torch.utils.data as data
import torch
from tqdm import tqdm

In [3]:
# Turn the Google Drive "share" link into a direct-download link, load the CSV
# with pandas and show the first sentence of the "text" column.
share_link = "https://drive.google.com/file/d/1cxg7dKxIBDtaqn9kD4b3XeboTwOKADlJ/view?usp=sharing"
file_id = share_link.split("/")[-2]  # second-to-last path segment of the share link
url = "https://drive.google.com/uc?id=" + file_id  # `url` is read again later by trainer()
df = pd.read_csv(url)
sent = df["text"].values[0]
print(sent)

Thousands of demonstrators have marched through London to protest the war in Iraq and demand the withdrawal of British troops from that country .


In [4]:
# Encode the sample sentence with the cased BERT tokenizer, padding or
# truncating to a fixed length of 512 and returning PyTorch tensors.
tokenizer = BertTokenizerFast.from_pretrained("bert-base-cased")
text_tokenized = tokenizer(
    sent,
    padding="max_length",
    max_length=512,
    truncation=True,
    return_tensors="pt",
)
print(text_tokenized)

{'input_ids': tensor([[  101, 26159,  1104,  8568,  4487,  5067,  1138,  9639,  1194,  1498,
          1106,  5641,  1103,  1594,  1107,  5008,  1105,  4555,  1103, 10602,
          1104,  1418,  2830,  1121,  1115,  1583,   119,   102,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,  

In [5]:
# Map the ids of the (single) encoded sentence back to text to inspect the
# special tokens and padding that the tokenizer added.
first_ids = text_tokenized["input_ids"][0]
tokenizer.decode(first_ids)

'[CLS] Thousands of demonstrators have marched through London to protest the war in Iraq and demand the withdrawal of British troops from that country. [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [P

In [6]:
# Sub-word tokens (special tokens included) of the encoded sentence.
tokens = tokenizer.convert_ids_to_tokens(text_tokenized["input_ids"][0])
# For every token, the index of the word it belongs to in the original
# sentence.
word_ids = text_tokenized.word_ids()
print(tokens, word_ids, sep="\n")

['[CLS]', 'Thousands', 'of', 'demons', '##tra', '##tors', 'have', 'marched', 'through', 'London', 'to', 'protest', 'the', 'war', 'in', 'Iraq', 'and', 'demand', 'the', 'withdrawal', 'of', 'British', 'troops', 'from', 'that', 'country', '.', '[SEP]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PA

In [7]:
def get_data_label_details(df):
    """Collect the label vocabulary of a dataset.

    Args:
        df: DataFrame with a "labels" column whose entries are strings of
            whitespace-separated labels.

    Returns:
        A tuple ``(num_unique_labels, label_to_idx, idx_to_label)``: the number
        of distinct labels, a dict from label to its rank in sorted order, and
        the inverse dict from rank to label.

    Side effects:
        Prints the number of unique labels.
    """
    # Build the set directly instead of using a list comprehension purely for
    # its side effects.
    unique_labels = {
        token_lb
        for sent_label in df["labels"].values.tolist()
        for token_lb in sent_label.split()
    }

    num_unique_labels = len(unique_labels)
    print(f"Number of Unique Labels: {num_unique_labels}")

    # Sort once and derive both mappings from the same ordering.
    idx_to_label = dict(enumerate(sorted(unique_labels)))
    label_to_idx = {label: idx for idx, label in idx_to_label.items()}
    return num_unique_labels, label_to_idx, idx_to_label

def align_label(tokenized_sent, labels, label_to_idx):
    """Align word-level labels with the tokens of an encoded sentence.

    Args:
        tokenized_sent: encoding that exposes ``word_ids()``; that call yields,
            per token, the index of the source word or ``None`` for special
            tokens.
        labels: word-level labels, indexed by word position.
        label_to_idx: mapping from label to integer id.

    Returns:
        A list with one integer per token: the label id for the first token of
        each word, and -100 for special tokens, for further tokens of the same
        word, and for words whose label is missing from ``labels`` or unknown
        to ``label_to_idx``. -100 is the value the training loop in this
        notebook filters out before computing accuracy.
    """
    label_ids = []
    previous_word_idx = None

    for word_idx in tokenized_sent.word_ids():
        if word_idx is None or word_idx == previous_word_idx:
            # Special token, or another piece of a word that is already labelled.
            label_ids.append(-100)
        else:
            try:
                label_ids.append(label_to_idx[labels[word_idx]])
            except (KeyError, IndexError):
                # Unknown label, or fewer labels than words. Only lookup
                # failures are masked; any other error is a real bug and must
                # propagate (the former bare `except:` hid those too).
                label_ids.append(-100)

        previous_word_idx = word_idx

    return label_ids
            
    

In [8]:
class NERDataset(data.Dataset):
    """Token-classification dataset read from a CSV with "text" and "labels" columns.

    Each item is a pair ``(encoding, label_ids)``: the tokenizer output for one
    sentence and a ``torch.LongTensor`` with the per-token label ids produced
    by ``align_label``.
    """

    def __init__(self, filepath, tokenizer, start=None, end=None):
        """Read the CSV, build the label vocabulary and encode the sentences.

        Args:
            filepath: path or URL handed to ``pd.read_csv``.
            tokenizer: callable used to encode each sentence; its result must
                expose ``word_ids()`` (required by ``align_label``).
            start: optional first row of the slice to keep.
            end: optional end row (exclusive) of the slice to keep.
        """
        df = pd.read_csv(filepath)
        # The vocabulary is computed on the whole file, before slicing, so
        # datasets built from different slices share the same label ids.
        self.num_unique_labels, self.label_to_idx, self.idx_to_label = get_data_label_details(df)
        # Apply the slice when either bound is given; previously a lone
        # `start` or `end` was silently ignored.
        if start is not None or end is not None:
            df = df[start:end]
        labels = [label.split() for label in df['labels'].values.tolist()]  # one list of NER labels per sentence
        sentences = df["text"].values.tolist()  # list of sentences
        self.txt = [tokenizer(sent, padding="max_length", max_length=512, truncation=True, return_tensors="pt") for sent in sentences]
        self.labels = [align_label(sent, label, label_to_idx=self.label_to_idx) for sent, label in zip(self.txt, labels)]
        self.len = len(self.labels)

    def __len__(self):
        """Return the number of sentences kept after slicing."""
        return self.len

    def __getitem__(self, index):
        """Return ``(encoding, label_ids)`` for the sentence at ``index``."""
        return self.txt[index], torch.LongTensor(self.labels[index])

In [9]:
class BertModel(torch.nn.Module):
    """Thin ``nn.Module`` wrapper around ``BertForTokenClassification`` ("bert-base-cased")."""

    def __init__(self, num_unique_labels):
        """Load the pretrained checkpoint with a token-classification head.

        Args:
            num_unique_labels: value passed as ``num_labels`` to the head.
        """
        super().__init__()

        self.bert = BertForTokenClassification.from_pretrained('bert-base-cased', num_labels=num_unique_labels)

    def forward(self, input_id, mask, label=None):
        """Run the wrapped model and return its tuple output.

        Args:
            input_id: token ids, forwarded as ``input_ids``.
            mask: attention mask, forwarded as ``attention_mask``.
            label: optional per-token label ids, forwarded as ``labels``.
                Defaults to ``None`` so the wrapper can also be used for
                inference.

        Returns:
            The tuple returned by the wrapped model (``return_dict=False``).
            The training loop in this notebook unpacks it as ``(loss, logits)``
            when labels are supplied; without labels the loss entry is absent
            (per the Transformers documentation).
        """
        return self.bert(input_ids=input_id, attention_mask=mask, labels=label, return_dict=False)

In [10]:
def _batch_accuracy_sum(logits, labels, ignore_index=-100):
    """Sum the per-sample token accuracies of one batch.

    Args:
        logits: tensor indexed as [sample, token, label].
        labels: tensor indexed as [sample, token]; positions equal to
            ``ignore_index`` are excluded from the accuracy.

    Returns:
        A Python float: the sum over samples of the fraction of kept tokens
        whose arg-max prediction equals the label. Samples without any kept
        token contribute 0 (a plain mean over them would be NaN).
    """
    keep = labels != ignore_index
    hits = ((logits.detach().argmax(dim=-1) == labels) & keep).sum(dim=1).float()
    counts = keep.sum(dim=1)
    has_tokens = counts > 0
    return (hits[has_tokens] / counts[has_tokens]).sum().item()


def trainer(model):
    """Fine-tune ``model`` for one epoch and report train/validation accuracy.

    Builds ``NERDataset`` objects from the notebook-level ``url`` (rows 0-900
    for training, rows 900-1000 for validation), trains with SGD (lr=3e-4,
    batch size 8) and prints the loss per batch plus the mean per-sentence
    token accuracy of both splits.

    Args:
        model: module called as ``model(input_ids, attention_mask, labels)``
            and returning ``(loss, logits)``.
    """
    torch.cuda.empty_cache()

    tokenizer = BertTokenizerFast.from_pretrained("bert-base-cased")
    trainDataset = NERDataset(filepath=url, tokenizer=tokenizer, start=0, end=900)
    validationDataset = NERDataset(filepath=url, tokenizer=tokenizer, start=900, end=1000)
    trainDataLoader = data.DataLoader(trainDataset, batch_size=8, shuffle=True)
    valDataLoader = data.DataLoader(validationDataset, batch_size=8)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    optimizer = torch.optim.SGD(model.parameters(), lr=3e-4)
    model.to(device)

    for epoch in range(1):

        # Plain floats: no device tensors are kept alive across batches.
        train_acc, val_acc = 0.0, 0.0

        model.train()
        for train_data, train_labels in tqdm(trainDataLoader):
            train_labels = train_labels.to(device)
            # Each item was encoded on its own with return_tensors="pt", so the
            # collated tensors presumably are (batch, 1, seq_len); drop axis 1.
            input_ids = train_data["input_ids"].squeeze(1).to(device)
            attn_masks = train_data["attention_mask"].squeeze(1).to(device)

            optimizer.zero_grad()
            loss, logits = model(input_ids, attn_masks, train_labels)
            train_acc += _batch_accuracy_sum(logits, train_labels)

            loss.backward()
            optimizer.step()
            print(f"Train Loss: {loss.item()}", end="\r")

        ### EVALUATION ON VALIDATION SET
        model.eval()
        # No gradients are needed here; without no_grad() every validation
        # batch would build an autograd graph and waste GPU memory.
        with torch.no_grad():
            for val_data, val_labels in tqdm(valDataLoader):
                val_labels = val_labels.to(device)
                input_ids = val_data["input_ids"].squeeze(1).to(device)
                attn_masks = val_data["attention_mask"].squeeze(1).to(device)

                loss, logits = model(input_ids, attn_masks, val_labels)
                val_acc += _batch_accuracy_sum(logits, val_labels)
                print(f"Val Loss: {loss.item()}", end="\r")

        print(f"Epoch: {epoch+1} | Train Acc: {train_acc/len(trainDataset)} | Val Acc: {val_acc/len(validationDataset)}")
            

            
    
    

In [11]:
# Size the classification head from the data instead of hard-coding 17, so the
# model cannot drift out of sync with the label vocabulary that NERDataset
# derives from the same file.
num_unique_labels = get_data_label_details(df)[0]
model = BertModel(num_unique_labels=num_unique_labels)
trainer(model)

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertForTokenClassification: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForTokenClassification were not initialized from the model checkpoint at bert-base-cas

Number of Unique Labels: 17
Number of Unique Labels: 17


  1%|          | 1/113 [00:01<02:57,  1.59s/it]

Train Loss: 2.8463356494903564

  2%|▏         | 2/113 [00:02<01:58,  1.07s/it]

Train Loss: 2.788968563079834

  3%|▎         | 3/113 [00:02<01:38,  1.12it/s]

Train Loss: 2.7465460300445557

  4%|▎         | 4/113 [00:03<01:28,  1.23it/s]

Train Loss: 2.731125593185425

  4%|▍         | 5/113 [00:04<01:22,  1.30it/s]

Train Loss: 2.6999924182891846

  5%|▌         | 6/113 [00:05<01:19,  1.35it/s]

Train Loss: 2.7598698139190674

  6%|▌         | 7/113 [00:05<01:16,  1.38it/s]

Train Loss: 2.6812663078308105

  7%|▋         | 8/113 [00:06<01:15,  1.40it/s]

Train Loss: 2.553077459335327

  8%|▊         | 9/113 [00:07<01:13,  1.41it/s]

Train Loss: 2.6066970825195312

  9%|▉         | 10/113 [00:07<01:12,  1.41it/s]

Train Loss: 2.5945498943328857

 10%|▉         | 11/113 [00:08<01:11,  1.42it/s]

Train Loss: 2.581069231033325

 11%|█         | 12/113 [00:09<01:11,  1.42it/s]

Train Loss: 2.5592598915100098

 12%|█▏        | 13/113 [00:09<01:10,  1.42it/s]

Train Loss: 2.493511438369751

 12%|█▏        | 14/113 [00:10<01:11,  1.39it/s]

Train Loss: 2.5110700130462646

 13%|█▎        | 15/113 [00:11<01:11,  1.37it/s]

Train Loss: 2.428541421890259

 14%|█▍        | 16/113 [00:12<01:10,  1.38it/s]

Train Loss: 2.4319424629211426

 15%|█▌        | 17/113 [00:12<01:08,  1.39it/s]

Train Loss: 2.3927838802337646

 16%|█▌        | 18/113 [00:13<01:07,  1.40it/s]

Train Loss: 2.2993080615997314

 17%|█▋        | 19/113 [00:14<01:06,  1.41it/s]

Train Loss: 2.364828109741211

 18%|█▊        | 20/113 [00:14<01:05,  1.41it/s]

Train Loss: 2.280402421951294

 19%|█▊        | 21/113 [00:15<01:05,  1.41it/s]

Train Loss: 2.3596718311309814

 19%|█▉        | 22/113 [00:16<01:04,  1.41it/s]

Train Loss: 2.3540115356445312

 20%|██        | 23/113 [00:17<01:03,  1.41it/s]

Train Loss: 2.31443452835083

 21%|██        | 24/113 [00:17<01:03,  1.41it/s]

Train Loss: 2.2864811420440674

 22%|██▏       | 25/113 [00:18<01:02,  1.41it/s]

Train Loss: 2.215543270111084

 23%|██▎       | 26/113 [00:19<01:01,  1.41it/s]

Train Loss: 2.199955940246582

 24%|██▍       | 27/113 [00:19<01:01,  1.40it/s]

Train Loss: 2.2381386756896973

 25%|██▍       | 28/113 [00:20<01:00,  1.41it/s]

Train Loss: 2.0571224689483643

 26%|██▌       | 29/113 [00:21<00:59,  1.40it/s]

Train Loss: 2.0178606510162354

 27%|██▋       | 30/113 [00:22<00:59,  1.40it/s]

Train Loss: 2.1157009601593018

 27%|██▋       | 31/113 [00:22<00:58,  1.40it/s]

Train Loss: 2.1204962730407715

 28%|██▊       | 32/113 [00:23<00:58,  1.39it/s]

Train Loss: 2.0583786964416504

 29%|██▉       | 33/113 [00:24<00:57,  1.39it/s]

Train Loss: 2.0164177417755127

 30%|███       | 34/113 [00:24<00:57,  1.38it/s]

Train Loss: 2.01900053024292

 31%|███       | 35/113 [00:25<00:56,  1.38it/s]

Train Loss: 1.847648024559021

 32%|███▏      | 36/113 [00:26<00:55,  1.38it/s]

Train Loss: 1.936570167541504

 33%|███▎      | 37/113 [00:27<00:55,  1.38it/s]

Train Loss: 1.9842159748077393

 34%|███▎      | 38/113 [00:27<00:54,  1.37it/s]

Train Loss: 1.937506079673767

 35%|███▍      | 39/113 [00:28<00:53,  1.37it/s]

Train Loss: 1.9005013704299927

 35%|███▌      | 40/113 [00:29<00:53,  1.37it/s]

Train Loss: 1.8304986953735352

 36%|███▋      | 41/113 [00:30<00:52,  1.37it/s]

Train Loss: 1.7771021127700806

 37%|███▋      | 42/113 [00:30<00:51,  1.37it/s]

Train Loss: 1.9704777002334595

 38%|███▊      | 43/113 [00:31<00:51,  1.37it/s]

Train Loss: 1.8122522830963135

 39%|███▉      | 44/113 [00:32<00:50,  1.37it/s]

Train Loss: 1.8488050699234009

 40%|███▉      | 45/113 [00:33<00:49,  1.36it/s]

Train Loss: 1.721806526184082

 41%|████      | 46/113 [00:33<00:49,  1.36it/s]

Train Loss: 1.8222670555114746

 42%|████▏     | 47/113 [00:34<00:48,  1.36it/s]

Train Loss: 1.6940425634384155

 42%|████▏     | 48/113 [00:35<00:47,  1.36it/s]

Train Loss: 1.8381980657577515

 43%|████▎     | 49/113 [00:35<00:47,  1.35it/s]

Train Loss: 1.6353451013565063

 44%|████▍     | 50/113 [00:36<00:46,  1.35it/s]

Train Loss: 1.6479272842407227

 45%|████▌     | 51/113 [00:37<00:45,  1.35it/s]

Train Loss: 1.602799892425537

 46%|████▌     | 52/113 [00:38<00:45,  1.35it/s]

Train Loss: 1.6200693845748901

 47%|████▋     | 53/113 [00:38<00:44,  1.34it/s]

Train Loss: 1.5907399654388428

 48%|████▊     | 54/113 [00:39<00:43,  1.34it/s]

Train Loss: 1.6638154983520508

 49%|████▊     | 55/113 [00:40<00:43,  1.34it/s]

Train Loss: 1.5564125776290894

 50%|████▉     | 56/113 [00:41<00:42,  1.35it/s]

Train Loss: 1.6122803688049316

 50%|█████     | 57/113 [00:41<00:41,  1.35it/s]

Train Loss: 1.5381038188934326

 51%|█████▏    | 58/113 [00:42<00:40,  1.35it/s]

Train Loss: 1.5649384260177612

 52%|█████▏    | 59/113 [00:43<00:40,  1.35it/s]

Train Loss: 1.4519399404525757

 53%|█████▎    | 60/113 [00:44<00:39,  1.35it/s]

Train Loss: 1.449820876121521

 54%|█████▍    | 61/113 [00:44<00:38,  1.35it/s]

Train Loss: 1.5218822956085205

 55%|█████▍    | 62/113 [00:45<00:37,  1.35it/s]

Train Loss: 1.3204002380371094

 56%|█████▌    | 63/113 [00:46<00:36,  1.35it/s]

Train Loss: 1.3483514785766602

 57%|█████▋    | 64/113 [00:47<00:36,  1.36it/s]

Train Loss: 1.4551112651824951

 58%|█████▊    | 65/113 [00:47<00:35,  1.36it/s]

Train Loss: 1.477880835533142

 58%|█████▊    | 66/113 [00:48<00:34,  1.36it/s]

Train Loss: 1.2677571773529053

 59%|█████▉    | 67/113 [00:49<00:33,  1.36it/s]

Train Loss: 1.3225606679916382

 60%|██████    | 68/113 [00:50<00:32,  1.37it/s]

Train Loss: 1.355446696281433

 61%|██████    | 69/113 [00:50<00:32,  1.37it/s]

Train Loss: 1.2576168775558472

 62%|██████▏   | 70/113 [00:51<00:31,  1.37it/s]

Train Loss: 1.3051445484161377

 63%|██████▎   | 71/113 [00:52<00:30,  1.37it/s]

Train Loss: 1.3845844268798828

 64%|██████▎   | 72/113 [00:52<00:29,  1.37it/s]

Train Loss: 1.4351516962051392

 65%|██████▍   | 73/113 [00:53<00:29,  1.38it/s]

Train Loss: 1.1989777088165283

 65%|██████▌   | 74/113 [00:54<00:28,  1.38it/s]

Train Loss: 1.4758963584899902

 66%|██████▋   | 75/113 [00:55<00:27,  1.38it/s]

Train Loss: 1.2200794219970703

 67%|██████▋   | 76/113 [00:55<00:26,  1.39it/s]

Train Loss: 1.2862837314605713

 68%|██████▊   | 77/113 [00:56<00:25,  1.39it/s]

Train Loss: 1.4166233539581299

 69%|██████▉   | 78/113 [00:57<00:25,  1.39it/s]

Train Loss: 1.4743211269378662

 70%|██████▉   | 79/113 [00:57<00:24,  1.39it/s]

Train Loss: 1.2223385572433472

 71%|███████   | 80/113 [00:58<00:23,  1.39it/s]

Train Loss: 1.2672268152236938

 72%|███████▏  | 81/113 [00:59<00:22,  1.39it/s]

Train Loss: 1.0807915925979614

 73%|███████▎  | 82/113 [01:00<00:22,  1.39it/s]

Train Loss: 1.3923779726028442

 73%|███████▎  | 83/113 [01:00<00:21,  1.40it/s]

Train Loss: 1.0612423419952393

 74%|███████▍  | 84/113 [01:01<00:20,  1.40it/s]

Train Loss: 1.1975762844085693

 75%|███████▌  | 85/113 [01:02<00:19,  1.40it/s]

Train Loss: 1.1627355813980103

 76%|███████▌  | 86/113 [01:02<00:19,  1.40it/s]

Train Loss: 1.2780336141586304

 77%|███████▋  | 87/113 [01:03<00:18,  1.40it/s]

Train Loss: 1.1771290302276611

 78%|███████▊  | 88/113 [01:04<00:17,  1.40it/s]

Train Loss: 1.1596755981445312

 79%|███████▉  | 89/113 [01:05<00:17,  1.40it/s]

Train Loss: 1.0284290313720703

 80%|███████▉  | 90/113 [01:05<00:16,  1.40it/s]

Train Loss: 1.157707929611206

 81%|████████  | 91/113 [01:06<00:15,  1.40it/s]

Train Loss: 1.0685611963272095

 81%|████████▏ | 92/113 [01:07<00:14,  1.40it/s]

Train Loss: 0.9150493741035461

 82%|████████▏ | 93/113 [01:07<00:14,  1.41it/s]

Train Loss: 1.2008963823318481

 83%|████████▎ | 94/113 [01:08<00:13,  1.41it/s]

Train Loss: 1.115121603012085

 84%|████████▍ | 95/113 [01:09<00:12,  1.41it/s]

Train Loss: 1.2155985832214355

 85%|████████▍ | 96/113 [01:10<00:12,  1.41it/s]

Train Loss: 0.9494081139564514

 86%|████████▌ | 97/113 [01:10<00:11,  1.41it/s]

Train Loss: 0.9408044219017029

 87%|████████▋ | 98/113 [01:11<00:10,  1.41it/s]

Train Loss: 1.2637041807174683

 88%|████████▊ | 99/113 [01:12<00:09,  1.41it/s]

Train Loss: 1.0685381889343262

 88%|████████▊ | 100/113 [01:12<00:09,  1.41it/s]

Train Loss: 1.237686276435852

 89%|████████▉ | 101/113 [01:13<00:08,  1.41it/s]

Train Loss: 1.1897286176681519

 90%|█████████ | 102/113 [01:14<00:07,  1.42it/s]

Train Loss: 0.9815884232521057

 91%|█████████ | 103/113 [01:15<00:07,  1.42it/s]

Train Loss: 1.0009734630584717

 92%|█████████▏| 104/113 [01:15<00:06,  1.42it/s]

Train Loss: 0.8060289025306702

 93%|█████████▎| 105/113 [01:16<00:05,  1.42it/s]

Train Loss: 0.9121584892272949

 94%|█████████▍| 106/113 [01:17<00:04,  1.41it/s]

Train Loss: 1.0746768712997437

 95%|█████████▍| 107/113 [01:17<00:04,  1.41it/s]

Train Loss: 1.2228753566741943

 96%|█████████▌| 108/113 [01:18<00:03,  1.41it/s]

Train Loss: 0.9460062980651855

 96%|█████████▋| 109/113 [01:19<00:02,  1.41it/s]

Train Loss: 0.9300405383110046

 97%|█████████▋| 110/113 [01:19<00:02,  1.42it/s]

Train Loss: 1.2049089670181274

 98%|█████████▊| 111/113 [01:20<00:01,  1.42it/s]

Train Loss: 0.848455011844635

 99%|█████████▉| 112/113 [01:21<00:00,  1.42it/s]

Train Loss: 0.7971936464309692

100%|██████████| 113/113 [01:21<00:00,  1.38it/s]


Train Loss: 1.2502652406692505

  8%|▊         | 1/13 [00:00<00:03,  3.48it/s]

Val Loss: 1.0571277141571045

 15%|█▌        | 2/13 [00:00<00:03,  3.67it/s]

Val Loss: 0.8070427775382996

 23%|██▎       | 3/13 [00:00<00:02,  3.68it/s]

Val Loss: 0.8700703978538513

 31%|███       | 4/13 [00:01<00:02,  3.73it/s]

Val Loss: 1.1048928499221802

 38%|███▊      | 5/13 [00:01<00:02,  3.72it/s]

Val Loss: 0.9100430011749268

 46%|████▌     | 6/13 [00:01<00:01,  3.75it/s]

Val Loss: 0.9745497703552246

 54%|█████▍    | 7/13 [00:01<00:01,  3.74it/s]

Val Loss: 1.2005308866500854

 62%|██████▏   | 8/13 [00:02<00:01,  3.76it/s]

Val Loss: 1.0058170557022095

 69%|██████▉   | 9/13 [00:02<00:01,  3.75it/s]

Val Loss: 0.8278400301933289

 77%|███████▋  | 10/13 [00:02<00:00,  3.76it/s]

Val Loss: 0.7783141136169434

 85%|████████▍ | 11/13 [00:02<00:00,  3.76it/s]

Val Loss: 1.0373300313949585

100%|██████████| 13/13 [00:03<00:00,  3.87it/s]

Val Loss: 1.1908291578292847Val Loss: 0.888100266456604Epoch: 1 | Train Acc: 0.7372771501541138 | Val Acc: 0.8369261026382446





In [16]:
# Ask PyTorch to release its cached, currently unused CUDA memory (see the
# torch.cuda.empty_cache documentation) before inspecting memory usage below.
torch.cuda.empty_cache()

In [18]:
# Print the CUDA allocator statistics for the current device. Guarded so the
# cell also runs on a CPU-only runtime instead of raising.
if torch.cuda.is_available():
    print(torch.cuda.memory_summary(device=None, abbreviated=False))
else:
    print("CUDA is not available; no GPU memory summary to show.")

|                  PyTorch CUDA memory summary, device ID 0                 |
|---------------------------------------------------------------------------|
|            CUDA OOMs: 2            |        cudaMalloc retries: 2         |
|        Metric         | Cur Usage  | Peak Usage | Tot Alloc  | Tot Freed  |
|---------------------------------------------------------------------------|
| Allocated memory      |  16640 KiB |  14190 MiB |   2218 GiB |   2218 GiB |
|       from large pool |  16640 KiB |  14185 MiB |   2214 GiB |   2214 GiB |
|       from small pool |      0 KiB |      6 MiB |      4 GiB |      4 GiB |
|---------------------------------------------------------------------------|
| Active memory         |  16640 KiB |  14190 MiB |   2218 GiB |   2218 GiB |
|       from large pool |  16640 KiB |  14185 MiB |   2214 GiB |   2214 GiB |
|       from small pool |      0 KiB |      6 MiB |      4 GiB |      4 GiB |
|---------------------------------------------------------------