# A simple tutorial for QA correlation prediction with pre-trained language model

In [1]:
# Róisín Luo, Colm O'Riordan (supervisor)

In [2]:
import sys
import os
import random
import math

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.cm as cm

from tqdm import tqdm

import csv
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F

%matplotlib inline

# GPU acceleration just in case

In [3]:
def get_hwacc_device_v3():

    device = torch.device('cpu')
    
    if torch.cuda.is_available():
        
        print(torch.cuda.get_device_name(0))
        print('CUDA memory Usage:')
        print('Allocated:', round(torch.cuda.memory_allocated(0)/1024**3,1), 'GB')
        print('Cached:   ', round(torch.cuda.memory_reserved(0)/1024**3,1), 'GB')
    
        device = torch.device('cuda')
    # MacOS
    elif hasattr(torch, "backends") and \
          hasattr(torch.backends, "mps") and \
          torch.backends.mps.is_available():
                
        device = torch.device('mps')
 
    print("GPU device is: ", device)
    
    return device

In [4]:
device = get_hwacc_device_v3()
#device = torch.device("cpu")
device

GPU device is:  mps


device(type='mps')

# Loading math QA dataset

In [5]:
from datasets import list_datasets

datasets_list = list_datasets()
len(datasets_list)

36541

In [7]:
for ds in datasets_list:
    if "wiki_qa" in ds:
        print(ds)

iapp_wiki_qa_squad
wiki_qa
wiki_qa_ar
wannaphong/iapp_wiki_qa_squad_oa
sedthh/cmu_wiki_qa
michaelthwan/wiki_qa_bart_1000row
michaelthwan/wiki_qa_bart_10000row
michaelthwan/oa_wiki_qa_bart_10000row


In [6]:
from datasets import load_dataset
dataset = load_dataset(path = "wiki_qa",
                       cache_dir = ".." + os.sep + ".." + os.sep + "Dataset_Collection", 
                       download_mode = "reuse_dataset_if_exists")

Found cached dataset wiki_qa (/Users/roisinjiaolinluo/Documents/Research/AI_Research/Roisins_Tutorials_on_DL_Applications/QA_correlation_prediction_with_pre-trained_BERT-GPT/../../Dataset_Collection/wiki_qa/default/0.1.0/d2d236b5cbdc6fbdab45d168b4d678a002e06ddea3525733a24558150585951c)


  0%|          | 0/3 [00:00<?, ?it/s]

## Investigating dataset shape

In [7]:
dataset

DatasetDict({
    test: Dataset({
        features: ['question_id', 'question', 'document_title', 'answer', 'label'],
        num_rows: 6165
    })
    validation: Dataset({
        features: ['question_id', 'question', 'document_title', 'answer', 'label'],
        num_rows: 2733
    })
    train: Dataset({
        features: ['question_id', 'question', 'document_title', 'answer', 'label'],
        num_rows: 20360
    })
})

In [8]:
dataset['train'][0]

{'question_id': 'Q1',
 'question': 'how are glacier caves formed?',
 'document_title': 'Glacier cave',
 'answer': 'A partly submerged glacier cave on Perito Moreno Glacier .',
 'label': 0}

In [9]:
len(dataset['train'])

20360

In [10]:
print(dataset['train'][0]['question'])

how are glacier caves formed?


In [11]:
print(dataset['train'][0]['answer'])

A partly submerged glacier cave on Perito Moreno Glacier .


## Setting dataset and splitting dataset

In [12]:
#Setting format to torch or tensorflow
dataset.set_format(type='torch', columns=['question', 'answer'])

In [13]:
dataset_train_ = dataset['train']
dataset_val_ = dataset['validation']
dataset_test_ = dataset['test']

# Wrapping the dataset with label

In [14]:
class QAPredictionDataset(torch.utils.data.Dataset):
    def __init__(self, 
                 dataset, 
                 paired_sampling_prob = 0.5, 
                 random_seed = 42):
        self.dataset = dataset
        self.paired_sampling_prob = paired_sampling_prob
        self.dataset_size = len(dataset)
        
        self.dataset_indices = list(np.arange(0, self.dataset_size))
        
        np.random.seed(random_seed)
        
    def __getitem__(self, index):
        q = self.dataset[index]['question']
        a = self.dataset[index]['answer']
        
        MAX_LEN=512
        #Truncate Q and A to MAX_LEN
        #TODO.
        
        if np.random.rand() < self.paired_sampling_prob:
            y = 1
        else:
            y = 0
            #Resampling a 'answer'
            new_index = int(np.random.choice(self.dataset_indices))
            a = self.dataset[new_index]['answer']
        
        return (q, a), y
    
    def __len__(self):
        return self.dataset_size

In [15]:
dataset_train = QAPredictionDataset(dataset_train_, paired_sampling_prob = 0.5)
dataset_val = QAPredictionDataset(dataset_val_, paired_sampling_prob = 0.5)
dataset_test = QAPredictionDataset(dataset_test_, paired_sampling_prob = 0.5)

In [16]:
for i in range(0, 3):
    (q,a), y = dataset_train[i]
    print("Q: ", q)
    print("A: ", a)
    print("Paired prob: ", y)
    print()

Q:  how are glacier caves formed?
A:  A partly submerged glacier cave on Perito Moreno Glacier .
Paired prob:  1

Q:  how are glacier caves formed?
A:  In modern politics, the most high profile political campaigns are focused on candidates for head of state or head of government , often a President or Prime Minister .
Paired prob:  0

Q:  how are glacier caves formed?
A:  Recovery was an international success and was named the best selling album of 2010 worldwide, joining The Eminem Show, which was the best seller of 2002.
Paired prob:  0



In [17]:
batch_size=8

In [18]:
dataloader_train = torch.utils.data.DataLoader(dataset_train, batch_size=batch_size, shuffle=True)
dataloader_val = torch.utils.data.DataLoader(dataset_val, batch_size=batch_size, shuffle=True)
dataloader_test = torch.utils.data.DataLoader(dataset_test, batch_size=batch_size, shuffle=True)

In [19]:
batch = next(iter(dataloader_train))
len(batch)

2

In [20]:
(q,a),y = batch

In [21]:
len(q)

8

In [22]:
len(a)

8

In [23]:
print(q[0])

how many countries are member of the eu?


In [24]:
print(a[0])

The EU was the recipient of the 2012 Nobel Peace Prize .


In [25]:
print(y[0])

tensor(1)


# Loading pre-trained model

In [26]:
from transformers import GPT2Tokenizer, GPT2Model

tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
gpt = GPT2Model.from_pretrained('gpt2')

## Testing embeddings

In [27]:
#a batch of text
text = q
print(text)

('how many countries are member of the eu?', 'where are poison dart frog seen', 'how many missions has the us sent to mars', 'what is vat tax?', 'when does college football training camp start', 'who wrote the song a little more country than that>', 'how many countries have english as an official language', 'who was in great britain before the anglo-saxons ?')


In [31]:
# Tokenized input
inputs = tokenizer(text[0], return_tensors='pt')

#sending to GPU
inputs = inputs.to(device)

print(inputs)

{'input_ids': tensor([[4919,  867, 2678,  389, 2888,  286,  262,  304,   84,   30]],
       device='mps:0'), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], device='mps:0')}


In [32]:
#sending model to GPU if possible
gpt.to(device)

#get embeddings.
with torch.no_grad():
    outputs = gpt(**inputs)

In [33]:
last_hidden_state = outputs.last_hidden_state

In [34]:
last_hidden_state.shape

torch.Size([1, 10, 768])

In [35]:
#Unlike BERT we use [CLS] as context.
#In GPT, we simply add all token representations.

context = torch.sum(last_hidden_state, dim = 1)

In [36]:
context.shape

torch.Size([1, 768])

In [37]:
#We use this position as embedding.
context[0].shape

torch.Size([768])

# Building correlation prediction model

In [54]:
from transformers import GPT2Tokenizer, GPT2Model




class QACorrelationPredictionModel(nn.Module):
    def __init__(self, device = torch.device("cpu")):
        
        super().__init__()
        
        self.device = device
  
        self.tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
        self.gpt = GPT2Model.from_pretrained('gpt2')


        #correlation prediction head.
        self.pred_head = nn.Sequential(
                    #nn.Dropout(p = 0.1),
                    
                    #Because we add all token representations as context.
                    #So we need to re-center the representations statistically.
                    #
                    #BUT for NLP task, we DONT use batchnorm. instead SHOULD use layernorm.
                    nn.LayerNorm(normalized_shape = (768*2)), 
                    nn.Linear(in_features = 768*2, out_features = 100, bias = True),
                    nn.ReLU(),
            
                    nn.Linear(in_features = 100, out_features = 1, bias = True),
                    nn.Sigmoid(),
                    )
        #self.pred_head.to(device)
    
    def forward(self, batch):
        self.gpt.eval()
        
        #A batch consists of ((q,a), y), in which
        #q and q are lists, y an array.

        #get embeddings
        qa_embed = []
        
        (q_list, a_list), y = batch
        
        for q, a in zip(q_list, a_list):
            
            #print()
            #print("Q: ", q)
            #print("A: ", a)
            
            # Tokenized Q and A
            inputs_q = self.tokenizer(q, return_tensors='pt')
            inputs_a = self.tokenizer(a, return_tensors='pt')
            
            inputs_q = inputs_q.to(self.device)
            inputs_a = inputs_a.to(self.device)
            
            #get context vectors.
            with torch.no_grad():
                outputs_q = self.gpt(**inputs_q)
                outputs_a = self.gpt(**inputs_q)
      
            last_hidden_state_q = outputs_q.last_hidden_state
            last_hidden_state_a = outputs_a.last_hidden_state
            #Unlike BERT we use [CLS] as context.
            #In GPT, we simply add all token representations.

            context_q = torch.sum(last_hidden_state_q, dim = 1)
            context_a = torch.sum(last_hidden_state_a, dim = 1)
            
            #print("last_hidden_state_q.shape = ", last_hidden_state_q.shape)
            context_q = context_q.squeeze(0)
            context_a = context_a.squeeze(0)
        
            #print()
            #print("q_embed.shape = ", q_embed.shape)
            #print("a_embed.shape = ", a_embed.shape)
        
            #Concatenate two 768 into 768*2
            embed = torch.cat([context_q, context_a])
            #print("embed.shape = ", embed.shape)
            
            qa_embed.append(embed)
            
        qa_embed = torch.stack(qa_embed)
        #print("qa_embed.shape = ", qa_embed.shape)
        
        #stop gradients.
        qa_embed = qa_embed.detach()
        
        probs = self.pred_head(qa_embed)
        probs = probs.squeeze(1)
        
        return probs

In [55]:
batch = next(iter(dataloader_train))

In [56]:
(q,a),labels = batch
labels

tensor([1, 0, 1, 1, 1, 1, 0, 0])

In [57]:
#testing the model simply.
model = QACorrelationPredictionModel(device = device)
model.to(device)
batch = next(iter(dataloader_train))
probs = model(batch)

In [58]:
probs.shape

torch.Size([8])

In [59]:
probs

tensor([0.4834, 0.4755, 0.4743, 0.4775, 0.4827, 0.4804, 0.4751, 0.4844],
       device='mps:0', grad_fn=<SqueezeBackward1>)

In [60]:
probs >= 0.5

tensor([False, False, False, False, False, False, False, False],
       device='mps:0')

In [61]:
torch.mean(probs >= 0.5, dtype = torch.float).cpu()

tensor(0.)

In [62]:
(_, _), labels = batch
labels.shape

torch.Size([8])

In [63]:
labels.shape

torch.Size([8])

In [64]:
labels

tensor([1, 0, 0, 0, 1, 1, 1, 1])

In [65]:
criterion = torch.nn.BCELoss()

#sending to GPU.
probs = probs.to(device)
labels = labels.float().to(device)

criterion(probs, labels)

tensor(0.6994, device='mps:0', grad_fn=<BinaryCrossEntropyBackward0>)

# Training one epoch

In [66]:
%matplotlib inline

#from IPython.display import display, clear_output
from IPython import display

def train_one_epoch(
          model, 
          device, 
          dataloader, 
          optimizer, 
          criterion,
          epoch,
          max_batches = None):
    
    # Enable gradient computing
    model.to(device)
    model.train()
    
    if max_batches is None:
        max_batches = len(dataloader)
    
    #some statistics
    
    #averaged loss in current epoch.
    epoch_loss = 0.0
    total_loss = 0.0
    
    #accuracy in current epoch
    batch_accuracy = 0.0
    #accuracy in current batch
    epoch_accuracy = 0.0
    
    #how many samples predicted correct.
    epoch_corrects = 0.0
    #how many samples trained in this epoch
    epoch_total = 0.0
    
    for batch_idx, batch in enumerate(dataloader, 1):
        
        (q,a), labels_ = batch
        
        labels = labels_.float()
        #sending labels to GPU if possible
        labels = labels.to(device)

        optimizer.zero_grad()
        
        #predictions.
        preds = model(batch)
        
        #computing BCE
        loss = criterion(preds, labels)
           
        #computing gradients
        loss.backward()
        
        #optimizing the classifier, Notice: the GPT is fixed.
        optimizer.step()
        
        
        #computing accuracy in a batch
        # torch.max() returns values, indices
        preds_ = (preds >= 0.5).int().cpu().data
        #batch_accuracy = torch.mean(preds > 0.5, dtype = torch.float).detach().cpu()
        #batch_accuracy = torch.mean(preds_.float()).detach().cpu()
        
        #computing the total loss and average loss in one epoch
        total_loss += loss.detach().cpu().numpy()
        epoch_loss = total_loss / batch_idx
        
        #computing the correct and total samples
        batch_corrects = torch.sum(labels_.cpu().data == preds_, dtype = torch.int)
        batch_accuracy = batch_corrects / len(labels_)
        epoch_corrects += batch_corrects
        epoch_total += len(labels_)
        epoch_accuracy = epoch_corrects / epoch_total         

        #Updating training displays.
        display.clear_output(wait=True)
        
        display.display('Epoch {} [{}/{} ({:.0f}%)]'.format(
                    epoch, batch_idx, 
                    len(dataloader), 
                    100. * (batch_idx / len(dataloader))))
        
        display.display('* batch accuracy {:.2f}% epoch accuracy {:.2f}%'.format(
                    100. * batch_accuracy, 100. * epoch_accuracy))
        
        display.display('* batch loss {:.6f} epoch loss {:.6f}'.format(
                    loss.item(), epoch_loss))
        display.display('* batch_corrects {}'.format(batch_corrects))
        
        if batch_idx > max_batches:
            break
    
    return epoch_loss, epoch_accuracy

In [67]:
model = QACorrelationPredictionModel(device = device)
_ = model.to(device)

In [68]:
batch_size = 64

dataloader_train = torch.utils.data.DataLoader(dataset_train, batch_size=batch_size, shuffle=True)

learning_rate = 0.001

#for language models, Adam is a good option.The learning rate 
#typically less than 0.001 for stabability.
optimizer = torch.optim.Adam(
                        model.pred_head.parameters(), 
                        lr = learning_rate,
                        #momentum = 0.9, 
                        #weight_decay = 5e-4
                      )

#Loss function
criterion = torch.nn.BCELoss()

epoch_loss, epoch_accuracy = train_one_epoch(
          model, 
          device, 
          dataloader_train, 
          optimizer, 
          criterion,
          epoch = 1,
          max_batches = 500)



'* batch accuracy 50.00% epoch accuracy 50.09%'

'* batch loss 0.692932 epoch loss 0.698094'

'* batch_corrects 32'

KeyboardInterrupt: 

# Save/load model

In [52]:
import os

def save_model(model, model_path):
    
    save_path = os.path.normpath(os.path.dirname(model_path)).rstrip(os.path.sep)
        
    if not os.path.exists(save_path):
        os.makedirs(save_path)
        
    print("Save model weights to: ", model_path)
    torch.save(model.state_dict(), model_path)  

In [76]:
save_model(model, "models/qa_model_gpt2.pth")

Save model weights to:  models/qa_model_distilbert-base-uncased.pth


In [54]:
def load_model(model_path, device):
    model = QACorrelationPredictionModel(device)
    
    if os.path.exists(model_path):
        #re-loading
        model.load_state_dict(torch.load(model_path, map_location = device)) 
        print("Loaded model weights from: ", model_path)
    else:
        print("Model weights not found.")
        
    return model

In [55]:
model = load_model("models/qa_model_gpt2.pth", device)

Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertModel: ['vocab_transform.weight', 'vocab_transform.bias', 'vocab_layer_norm.bias', 'vocab_projector.bias', 'vocab_layer_norm.weight', 'vocab_projector.weight']
- This IS expected if you are initializing DistilBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Loaded model weights from:  models/qa_model_distilbert-base-uncased.pth


# Complete training

In [57]:
def train(model, 
          device, 
          dataloader, 
          optimizer,
          criterion,
          epochs,
          scheduler = None):
    
    #if model_path is not None and not os.path.exists(model_path):
    #    os.makedirs(model_path)
    
    loss_hist = []
    accuracy_hist = []
    
    for epoch in range(1, epochs + 1):
        
        epoch_loss, epoch_accuracy = train_one_epoch(
          model, 
          device, 
          dataloader, 
          optimizer, 
          criterion,
          epoch,
          max_batches = None)
    
        if scheduler:
            #adjusting LR is necessary
            scheduler.step()
            
        loss_hist.append(epoch_loss)
        accuracy_hist.append(epoch_accuracy)
   
    return loss_hist, laccuracy_hist

In [58]:
batch_size = 64

dataloader_train = torch.utils.data.DataLoader(dataset_train, batch_size=batch_size, shuffle=True)

learning_rate = 0.001

#for language models, Adam is a good option.The learning rate 
#typically less than 0.001 for stabability.
optimizer = torch.optim.Adam(
                        model.pred_head.parameters(), 
                        lr = learning_rate,
                        #momentum = 0.9, 
                        #weight_decay = 5e-4
                      )

#scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 
#                                                step_size = 10, #dropping learning-rate every 10 steps. 
#                                                gamma = 0.1)
scheduler = None

#Loss function
criterion = torch.nn.BCELoss()

In [59]:
loss_hist, laccuracy_hist = train(model, 
          device, 
          dataloader_train, 
          optimizer,
          criterion,
          epochs = 50,
          scheduler = scheduler)



'* batch accuracy 85.94% epoch accuracy 85.94%'

'* batch loss 0.331381 epoch loss 0.336008'

'* batch_corrects 55'

KeyboardInterrupt: 

In [None]:
#plt.plot(loss_hist)

In [None]:
#plt.plot(laccuracy_hist)

# Evaluation

In [60]:
def evaluate(model, 
             device, 
             dataloader,
             max_batches = None):
    
    # Disable gradient computing
    model.eval()
    
    if max_batches is None:
        max_batches = len(dataloader)
    
    batch_accuracy = 0.0
    total_accuracy = 0.0
    
    total_corrects = 0.0
    total_entries = 0.0
    
    for batch_idx, batch in enumerate(dataloader, 1):
        
        (q,a), labels = batch
        
        
        #no need to track gradients
        with torch.no_grad():
            preds = model(batch)
                        
        #computing accuracy in a batch
        preds_ = (preds >= 0.5).int().cpu().data
        
        #computing the correct and total samples
        batch_corrects = torch.sum(labels.cpu().data == preds_, dtype = torch.int)
        batch_accuracy = batch_corrects / len(labels)
        
        total_corrects += batch_corrects
        total_entries += len(labels)
        total_accuracy = total_corrects / total_entries         

        #Updating training displays.
        display.clear_output(wait=True)
        
        display.display('Evaluation: [{}/{} ({:.0f}%)]'.format(
                    batch_idx, 
                    len(dataloader), 
                    100. * (batch_idx / len(dataloader))))
        
        display.display('* batch accuracy {:.2f}% total accuracy {:.2f}%'.format(
                    100. * batch_accuracy, 100. * total_accuracy))
        
        display.display('* total_corrects {} total_entries {}'.format(total_corrects, total_entries))
        
        if batch_idx >= max_batches:
            break
                
    return total_accuracy

In [61]:
batch_size = 64

dataloader_test = torch.utils.data.DataLoader(dataset_test, batch_size=batch_size, shuffle=True)

total_accuracy = evaluate(model, 
             device, 
             dataloader_test,
             max_batches = 10)

print(total_accuracy)



'* batch accuracy 87.50% total accuracy 82.03%'

'* total_corrects 525.0 total_entries 640.0'

tensor(0.8203)


# Predict

In [62]:
batch_size = 100

dataloader_test = torch.utils.data.DataLoader(dataset_test, batch_size=batch_size, shuffle=True)

In [63]:
batch = next(iter(dataloader_test))

In [64]:
(q,a),labels = batch

In [65]:
preds = model(batch)

In [66]:
#prediction probability.
preds

tensor([2.0273e-02, 3.2407e-02, 3.5677e-01, 4.7347e-03, 2.5171e-01, 9.0781e-01,
        1.6453e-03, 1.4629e-01, 5.2216e-01, 9.1657e-01, 2.1846e-02, 9.8038e-02,
        6.6185e-01, 3.4565e-04, 7.9286e-02, 3.0783e-01, 1.4309e-01, 8.7267e-01,
        1.8307e-03, 1.2167e-01, 2.4902e-01, 7.6512e-01, 5.2146e-01, 9.2054e-01,
        4.6154e-01, 8.7516e-01, 1.1504e-02, 9.7972e-01, 8.2628e-01, 8.4387e-01,
        2.4516e-02, 8.2525e-01, 2.9915e-01, 9.1337e-01, 6.8162e-01, 9.4618e-01,
        7.4872e-01, 1.8747e-01, 3.7354e-01, 8.1918e-01, 8.5369e-01, 8.8730e-01,
        7.3736e-01, 4.4461e-03, 7.1582e-01, 6.9760e-01, 3.3097e-01, 7.3234e-01,
        9.0394e-02, 7.5479e-03, 1.1329e-01, 5.9142e-01, 7.5703e-01, 4.7870e-02,
        3.8350e-03, 3.4489e-02, 8.4882e-01, 8.2115e-01, 9.1722e-01, 7.1167e-03,
        6.7794e-04, 5.1792e-01, 8.1934e-01, 4.1382e-01, 3.2099e-02, 7.0505e-01,
        1.7920e-02, 5.4763e-02, 9.2899e-01, 7.5020e-01, 9.6094e-01, 5.8948e-01,
        7.6573e-01, 7.5728e-01, 6.5120e-

In [67]:
labels

tensor([0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1,
        1, 1, 0, 1, 1, 1, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 1, 0, 0, 1, 0, 0, 1,
        0, 0, 0, 1, 1, 0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 1, 1, 0,
        1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 0, 1, 0, 1, 1, 0, 1,
        0, 0, 1, 0])

In [68]:
(preds > 0.5).int()

tensor([0, 0, 0, 0, 0, 1, 0, 0, 1, 1, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 1, 1,
        0, 1, 0, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 0, 1, 1, 0, 1,
        0, 0, 0, 1, 1, 0, 0, 0, 1, 1, 1, 0, 0, 1, 1, 0, 0, 1, 0, 0, 1, 1, 1, 1,
        1, 1, 1, 1, 0, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 0, 1,
        0, 0, 1, 0], device='mps:0', dtype=torch.int32)

In [69]:
#average accuracy
torch.mean((preds > 0.5).int().cpu() == labels.int().cpu(), dtype = torch.float).cpu().item()

0.8100000023841858