In [38]:
import urllib
import os
import collections
import numpy as np
from tqdm import tqdm
import torch
from torch import nn
import torchtext
from torch.utils.data import Dataset, DataLoader, random_split

In [39]:
# Make sure the download directory exists before torchtext writes into it.
os.makedirs('data', exist_ok=True)

# Tokenizer used for every news text in this notebook.
tokenizer = torchtext.data.utils.get_tokenizer('basic_english', language="en")

# AG_NEWS yields a (train, test) pair; materialize each split as a list.
dataset_train, dataset_test = (
    list(split) for split in torchtext.datasets.AG_NEWS(root='./data')
)

In [40]:
# Pass 1: tokenize every training text and collect the labels shifted down by
# one (presumably so they are zero-based -- confirm against the dataset's labels).
print('Start Tokenizing...')
tokenized_data, label_data = [], []
for label, text in tqdm(dataset_train):
    tokenized_data.append(tokenizer(text))
    label_data.append(label - 1)

# Pass 2: count token frequencies over the tokenized training set.
print('Making Vocab...')
counter = collections.Counter()
for token_list in tqdm(tokenized_data):
    counter.update(token_list)

# Build the vocabulary from the counts; tokens that are not in the vocabulary
# are looked up as "<unk>" via the default index.
vocab = torchtext.vocab.vocab(
    counter,
    min_freq=1,
    specials=["<unk>", "<pad>"],
)
vocab.set_default_index(vocab["<unk>"])

Start Tokenizing...


100%|██████████| 120000/120000 [00:04<00:00, 27084.26it/s]


Making Vocab...


100%|██████████| 120000/120000 [00:01<00:00, 98035.04it/s]


In [41]:
class AGNewsDataset(Dataset):
    """Map-style dataset yielding ``(token_id_tensor, label_tensor)`` pairs.

    Tokens are converted to ids with the module-level ``vocab`` object.

    Args:
        tokenized_data: iterable of token lists, one list per sample.
        label_data: sequence of integer labels aligned with ``tokenized_data``.
        max_seq: maximum number of tokens kept per sample; longer samples are
            truncated to their first ``max_seq`` tokens. ``None`` disables
            truncation.

    Raises:
        ValueError: if ``max_seq`` is negative.
    """

    def __init__(self, tokenized_data, label_data, max_seq = 256):
        if max_seq is not None and max_seq < 0:
            raise ValueError(f"max_seq must be non-negative or None, got {max_seq}")

        self.classes = ['World', 'Sports', 'Business', 'Sci/Tech']
        self.max_seq = max_seq

        # max_seq is enforced here by slicing each token list before the id
        # lookup, so over-long samples never reach the model at full length.
        self.x = [
            [vocab[token] for token in tokens[:max_seq]]
            for tokens in tqdm(tokenized_data)
        ]
        self.y = label_data

    def __len__(self):
        """Return the number of samples."""
        return len(self.x)

    def __getitem__(self, idx):
        """Return the ``idx``-th sample as ``(token_ids, label)`` tensors."""
        # dtype is explicit so an empty token list still gives an integer
        # tensor (torch.tensor([]) would default to float).
        token_ids = torch.tensor(self.x[idx], dtype=torch.long)
        return token_ids, torch.tensor(self.y[idx])

In [42]:
# Build the dataset from the tokenized training samples and their labels;
# max_seq=256 is forwarded to the AGNewsDataset constructor.
news_dataset = AGNewsDataset(
    tokenized_data=tokenized_data,
    label_data=label_data,
    max_seq=256,
)

100%|██████████| 120000/120000 [00:04<00:00, 26562.33it/s]


In [43]:
# Hold out the part of news_dataset not covered by train_ratio for validation.
train_ratio = 0.8
train_size = int(train_ratio * len(news_dataset))
valid_size = len(news_dataset) - train_size

# Seed the split so the same samples land in the validation set on every run;
# without a fixed generator the reported accuracy is not comparable between
# kernel restarts.
split_seed = 42
train_dataset, valid_dataset = random_split(
    news_dataset,
    [train_size, valid_size],
    generator=torch.Generator().manual_seed(split_seed),
)

In [44]:
class LSTM(nn.Module):
    """Embedding -> single-layer LSTM -> linear classifier over one timestep.

    Args:
        vocab_size: number of rows in the embedding table.
        embedding_size: dimension of each token embedding.
        hidden_size: dimension of the LSTM hidden state.
        output_size: number of output logits per sample.
        padding_idx: optional token id used for padding. When given, the
            classifier reads the LSTM output at the last non-padding position
            of each sequence instead of at the final (possibly padded)
            position. Assumes sequences are right-padded and that the padding
            id does not occur inside the real text. ``None`` (default) keeps
            the behaviour of always reading the final position.
    """

    def __init__(self, vocab_size, embedding_size, hidden_size, output_size, padding_idx=None):
        super().__init__()
        self.hidden_size = hidden_size
        self.padding_idx = padding_idx
        self.embedding = nn.Embedding(vocab_size, embedding_size)
        self.lstm = nn.LSTM(embedding_size, hidden_size, batch_first=True)
        self.out =  nn.Linear(hidden_size, output_size)

    def forward(self, input):
        """Return logits of shape [batch_size, output_size].

        Args:
            input: integer token ids of shape [batch_size, sentence_len].
        """
        embedded = self.embedding(input)
        # embedded: [batch_size, sentence_len, emb_size]
        output, _ = self.lstm(embedded)
        # output: [batch_size, sentence_len, hidden_size]

        if self.padding_idx is None:
            # No padding information: read the final position of every row.
            last = output[:, -1, :]
        else:
            # Number of real tokens per row; clamp so an all-padding row still
            # indexes position 0 instead of wrapping around to -1.
            lengths = (input != self.padding_idx).sum(dim=1).clamp(min=1)
            rows = torch.arange(output.size(0), device=output.device)
            last = output[rows, lengths - 1]
        # last: [batch_size, hidden_size]

        # Project only the selected timestep; the other positions are not
        # needed for classification.
        return self.out(last)



In [45]:
# Instantiate the classifier: one embedding row per vocabulary entry and one
# output logit per entry in news_dataset.classes.
model = LSTM(
    vocab_size=len(vocab),
    embedding_size=64,
    hidden_size=32,
    output_size=len(news_dataset.classes),
)

In [57]:
# Training hyperparameters.
batch_size = 256
learning_rate = 1e-3
num_epochs = 1

# Use CUDA when PyTorch reports it as available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [58]:
def pad_collate_fn(batch):
    """Collate ``(sequence, label)`` pairs into one padded batch.

    Args:
        batch: list of ``(sequence_tensor, label_tensor)`` pairs.

    Returns:
        Tuple ``(padded, labels)``: the sequences padded to the longest one in
        the batch with the module-level ``vocab['<pad>']`` id (batch dimension
        first), and the labels stacked into a single tensor.
    """
    sequences = [sequence for sequence, _ in batch]
    labels = [target for _, target in batch]
    padded = nn.utils.rnn.pad_sequence(
        sequences,
        batch_first=True,
        padding_value=vocab['<pad>'],
    )
    return padded, torch.stack(labels)



In [59]:
# Training batches are reshuffled on each pass; validation order stays fixed.
# Both loaders pad their batches through pad_collate_fn.
train_dataloader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=pad_collate_fn,
)
valid_dataloader = DataLoader(
    valid_dataset,
    batch_size=batch_size,
    shuffle=False,
    collate_fn=pad_collate_fn,
)

In [60]:
# Loss for the multi-class logits produced by the model.
criterion = nn.CrossEntropyLoss()

# Move the model to the selected device, then hand its parameters to Adam.
model = model.to(device)
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=learning_rate,
)

In [73]:
def accuracy(output, label):
    """Count how many predictions match ``label``.

    Args:
        output: tensor of scores; the predicted class is the argmax over the
            last dimension.
        label: tensor of target class indices, compared element-wise with the
            predictions.

    Returns:
        A tensor holding the number of matching predictions (a count, not a
        ratio).
    """
    hits = output.argmax(dim=-1) == label
    return hits.sum()
    

In [74]:
def train(dataloader, epoch):
    """Run one optimisation pass over ``dataloader`` and print loss and accuracy.

    Relies on the module-level ``model``, ``optimizer``, ``criterion``,
    ``device`` and ``accuracy``.

    Args:
        dataloader: iterable of ``(data, label)`` batches that supports ``len()``.
        epoch: epoch index, used only in the progress-bar and log text.
    """
    model.train()
    epoch_loss = 0
    epoch_acc = 0
    data_num = 0
    for data, label in tqdm(dataloader, desc=f"Epoch {epoch}"):
        data, label = data.to(device), label.to(device)
        # Clear gradients left over from the previous step before the new
        # forward/backward pass.
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, label)
        acc = accuracy(output, label)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        epoch_acc += acc.item()
        data_num += data.size(0)

    # Loss is averaged over batches, accuracy over samples. data_num is an int
    # sample count, so it is used directly as the divisor (calling len() on it
    # raises TypeError). max(..., 1) keeps an empty loader from dividing by zero.
    mean_loss = epoch_loss / max(len(dataloader), 1)
    mean_acc = epoch_acc / max(data_num, 1)
    print(f"Train Epoch: {epoch}, Loss: {mean_loss}, Acc: {mean_acc}")


In [75]:
def evaluate(dataloader, epoch):
    """Score the model on ``dataloader`` without updating it.

    Relies on the module-level ``model``, ``criterion``, ``device`` and
    ``accuracy``.

    Args:
        dataloader: iterable of ``(data, label)`` batches that supports ``len()``.
        epoch: epoch index, used only in the progress-bar and log text.

    Returns:
        Fraction of samples classified correctly (0 for an empty loader).
    """
    model.eval()
    epoch_loss = 0
    epoch_acc = 0
    data_num = 0
    with torch.no_grad():
        for data, label in tqdm(dataloader, desc=f"Epoch {epoch}"):
            data, label = data.to(device), label.to(device)
            output = model(data)
            loss = criterion(output, label)
            acc = accuracy(output, label)
            epoch_loss += loss.item()
            epoch_acc += acc.item()
            data_num += data.size(0)

    # Loss is averaged over batches, accuracy over samples. data_num is an int
    # sample count, so it is used directly as the divisor (calling len() on it
    # raises TypeError). max(..., 1) keeps an empty loader from dividing by zero.
    mean_loss = epoch_loss / max(len(dataloader), 1)
    mean_acc = epoch_acc / max(data_num, 1)
    print(f"Evaluate Epoch: {epoch}, Loss: {mean_loss}, Acc: {mean_acc}")
    return mean_acc

In [76]:
# Train for num_epochs epochs, validating after each one and remembering the
# highest validation accuracy returned by evaluate().
total_acc = 0
best_acc = 0
for epoch in range(num_epochs):
    train(train_dataloader, epoch)
    valid_acc = evaluate(valid_dataloader, epoch)
    best_acc = max(valid_acc, best_acc)
    print('-'*50)

print(f"Best Accuracy: {best_acc}")

Epoch 0:   2%|▏         | 7/375 [00:00<00:10, 34.27it/s]

tensor([0, 2, 1, 2, 2, 0, 3, 3, 3, 2, 1, 3, 3, 2, 2, 2, 3, 3, 0, 0, 0, 0, 0, 3,
        1, 2, 3, 3, 2, 2, 3, 1, 0, 0, 3, 3, 2, 3, 3, 2, 2, 1, 2, 3, 1, 2, 3, 1,
        3, 1, 3, 3, 3, 3, 2, 0, 2, 2, 2, 3, 3, 2, 0, 0, 3, 0, 1, 2, 2, 1, 0, 3,
        1, 0, 3, 0, 3, 1, 1, 3, 0, 3, 1, 0, 3, 1, 1, 1, 1, 3, 0, 3, 0, 3, 1, 1,
        3, 1, 0, 1, 3, 3, 3, 3, 2, 3, 0, 3, 1, 0, 1, 3, 1, 0, 0, 0, 3, 2, 3, 3,
        2, 1, 2, 1, 3, 0, 2, 3, 0, 2, 2, 0, 2, 3, 2, 3, 1, 0, 2, 2, 3, 2, 1, 0,
        2, 1, 0, 1, 2, 1, 1, 3, 1, 3, 3, 2, 2, 0, 1, 2, 1, 1, 2, 1, 1, 3, 3, 3,
        3, 1, 1, 1, 0, 1, 3, 2, 1, 3, 1, 3, 2, 3, 2, 1, 2, 2, 2, 0, 2, 0, 1, 3,
        0, 1, 1, 3, 1, 0, 1, 1, 2, 2, 3, 1, 1, 0, 0, 1, 1, 0, 2, 1, 1, 1, 3, 1,
        2, 0, 3, 2, 0, 2, 3, 3, 3, 1, 2, 1, 0, 3, 2, 1, 0, 3, 1, 1, 3, 0, 3, 2,
        0, 3, 0, 2, 1, 1, 3, 1, 0, 0, 3, 1, 1, 2, 1, 2], device='cuda:0')
tensor([0, 2, 1, 2, 2, 0, 3, 0, 3, 2, 1, 3, 0, 2, 2, 2, 3, 3, 0, 0, 0, 0, 0, 3,
        1, 2, 3, 3, 2, 2, 3, 1, 2, 0, 3, 3, 2,

Epoch 0:   4%|▍         | 16/375 [00:00<00:09, 39.63it/s]

tensor([3, 2, 2, 3, 0, 0, 2, 2, 3, 0, 2, 2, 0, 1, 2, 2, 3, 0, 1, 3, 3, 3, 1, 1,
        0, 2, 2, 3, 2, 3, 0, 3, 1, 1, 2, 0, 1, 3, 0, 3, 0, 2, 2, 2, 3, 1, 1, 1,
        2, 1, 3, 1, 2, 1, 2, 2, 2, 0, 1, 0, 1, 0, 0, 2, 3, 0, 0, 2, 0, 2, 0, 0,
        1, 0, 1, 1, 0, 2, 1, 3, 2, 3, 3, 2, 1, 2, 3, 3, 0, 1, 2, 2, 3, 1, 1, 0,
        0, 1, 2, 2, 2, 1, 3, 3, 3, 1, 1, 1, 3, 1, 1, 1, 0, 2, 2, 3, 2, 2, 1, 0,
        3, 2, 1, 0, 3, 2, 2, 0, 3, 3, 2, 0, 1, 2, 3, 1, 0, 3, 0, 1, 1, 0, 0, 2,
        0, 1, 1, 0, 1, 3, 0, 3, 3, 1, 3, 3, 2, 0, 1, 1, 1, 0, 1, 2, 2, 2, 3, 3,
        3, 1, 2, 3, 3, 3, 0, 0, 0, 1, 1, 0, 2, 1, 1, 2, 3, 2, 3, 1, 3, 0, 1, 2,
        3, 3, 3, 0, 3, 0, 2, 3, 3, 3, 2, 3, 2, 1, 1, 2, 2, 0, 3, 3, 1, 3, 0, 2,
        2, 3, 2, 3, 0, 2, 0, 1, 1, 0, 3, 3, 0, 1, 3, 0, 1, 0, 1, 3, 2, 3, 3, 1,
        3, 2, 0, 0, 3, 3, 0, 3, 0, 2, 0, 1, 2, 3, 2, 2], device='cuda:0')
tensor([3, 2, 2, 3, 0, 0, 2, 2, 3, 0, 2, 2, 0, 1, 2, 2, 3, 0, 1, 3, 3, 3, 1, 1,
        0, 2, 2, 3, 2, 3, 0, 3, 1, 1, 2, 0, 1,

Epoch 0:   6%|▋         | 24/375 [00:00<00:09, 36.09it/s]

tensor([2, 0, 3, 0, 2, 0, 2, 1, 1, 2, 1, 1, 0, 0, 1, 1, 3, 3, 2, 2, 2, 3, 2, 2,
        2, 2, 2, 0, 0, 1, 0, 0, 1, 2, 0, 1, 1, 1, 2, 0, 1, 2, 3, 2, 2, 0, 1, 1,
        0, 3, 3, 2, 2, 2, 0, 2, 2, 1, 1, 3, 1, 0, 1, 1, 1, 1, 3, 1, 1, 0, 3, 1,
        0, 1, 0, 1, 3, 2, 2, 0, 3, 0, 0, 1, 0, 2, 3, 1, 3, 3, 3, 1, 1, 1, 2, 0,
        2, 0, 1, 0, 0, 1, 2, 3, 1, 0, 1, 3, 0, 0, 1, 1, 3, 0, 2, 2, 1, 0, 2, 0,
        0, 0, 3, 0, 2, 1, 0, 1, 2, 3, 0, 1, 1, 2, 0, 2, 1, 3, 1, 2, 3, 0, 1, 3,
        1, 0, 3, 3, 2, 1, 1, 3, 0, 3, 3, 2, 1, 3, 3, 2, 1, 3, 0, 3, 2, 2, 1, 2,
        0, 3, 1, 1, 2, 3, 1, 3, 3, 3, 0, 1, 2, 1, 2, 0, 2, 1, 0, 0, 1, 3, 3, 1,
        2, 3, 0, 2, 3, 3, 2, 2, 2, 3, 1, 1, 0, 2, 0, 1, 1, 1, 2, 1, 1, 1, 3, 3,
        2, 0, 3, 3, 0, 3, 3, 1, 1, 1, 2, 2, 1, 3, 0, 0, 0, 3, 0, 0, 0, 1, 1, 0,
        2, 3, 1, 1, 3, 1, 2, 0, 3, 0, 1, 1, 2, 3, 1, 0], device='cuda:0')
tensor([2, 0, 3, 0, 2, 0, 2, 0, 1, 2, 1, 1, 0, 0, 1, 1, 3, 3, 2, 2, 2, 3, 2, 2,
        3, 2, 2, 0, 0, 1, 0, 0, 1, 2, 3, 1, 1,

Epoch 0:   7%|▋         | 28/375 [00:00<00:09, 35.20it/s]

tensor([3, 1, 0, 0, 0, 3, 3, 3, 2, 3, 2, 0, 1, 1, 0, 3, 2, 2, 3, 3, 1, 2, 3, 0,
        1, 0, 2, 3, 3, 3, 2, 2, 2, 3, 3, 3, 0, 1, 1, 0, 3, 1, 2, 2, 3, 1, 1, 2,
        3, 0, 2, 0, 3, 1, 0, 1, 1, 1, 0, 1, 3, 1, 3, 2, 1, 3, 2, 2, 1, 1, 1, 3,
        1, 2, 3, 3, 2, 2, 0, 1, 2, 1, 2, 3, 3, 2, 0, 2, 1, 3, 2, 1, 3, 1, 0, 0,
        2, 0, 2, 1, 3, 3, 0, 2, 1, 1, 2, 0, 2, 1, 3, 2, 1, 1, 3, 2, 2, 0, 0, 3,
        1, 0, 0, 3, 1, 0, 1, 1, 0, 1, 1, 2, 2, 2, 3, 0, 3, 2, 2, 3, 1, 3, 2, 1,
        1, 3, 3, 3, 1, 0, 0, 2, 0, 0, 2, 3, 2, 0, 1, 2, 1, 1, 0, 3, 1, 1, 0, 3,
        1, 0, 0, 0, 3, 1, 3, 2, 2, 3, 3, 2, 2, 1, 2, 2, 3, 1, 2, 2, 2, 3, 3, 2,
        3, 3, 2, 0, 0, 0, 0, 0, 3, 2, 1, 3, 1, 1, 1, 0, 3, 0, 0, 1, 2, 3, 2, 3,
        3, 3, 2, 1, 0, 0, 0, 0, 3, 1, 1, 2, 1, 1, 2, 0, 2, 3, 0, 3, 3, 1, 1, 3,
        1, 0, 3, 1, 0, 0, 1, 1, 0, 1, 0, 0, 0, 3, 0, 3], device='cuda:0')
tensor([3, 1, 0, 0, 0, 3, 3, 3, 2, 3, 2, 0, 1, 1, 0, 3, 2, 2, 3, 3, 1, 0, 3, 0,
        1, 0, 2, 3, 3, 3, 2, 2, 2, 0, 3, 3, 0,

Epoch 0:  10%|▉         | 36/375 [00:01<00:09, 34.11it/s]

tensor([0, 2, 3, 1, 1, 0, 0, 2, 2, 0, 3, 1, 2, 0, 2, 0, 1, 1, 2, 0, 1, 3, 0, 1,
        1, 3, 2, 3, 0, 0, 2, 3, 2, 1, 2, 2, 1, 0, 1, 0, 0, 2, 1, 2, 0, 2, 2, 3,
        3, 3, 3, 2, 0, 3, 3, 1, 1, 2, 3, 2, 1, 3, 3, 1, 2, 2, 3, 0, 2, 1, 1, 1,
        2, 3, 1, 0, 2, 1, 1, 2, 2, 1, 3, 1, 3, 2, 3, 2, 1, 0, 2, 2, 1, 1, 0, 2,
        1, 2, 1, 3, 1, 2, 0, 2, 3, 2, 0, 3, 3, 0, 2, 3, 1, 1, 3, 3, 3, 2, 1, 3,
        1, 0, 1, 3, 2, 1, 3, 0, 3, 2, 0, 2, 0, 0, 3, 3, 1, 2, 1, 3, 2, 2, 3, 2,
        3, 0, 2, 0, 2, 3, 1, 1, 0, 1, 3, 1, 0, 1, 1, 3, 0, 3, 0, 0, 0, 2, 0, 0,
        3, 1, 1, 1, 3, 0, 2, 0, 1, 2, 0, 2, 3, 0, 3, 0, 0, 0, 2, 3, 2, 3, 3, 0,
        3, 3, 0, 0, 0, 3, 0, 2, 0, 0, 1, 3, 1, 2, 3, 0, 0, 1, 3, 3, 3, 0, 1, 2,
        3, 1, 1, 3, 0, 1, 1, 3, 0, 1, 3, 3, 0, 0, 0, 1, 2, 0, 1, 3, 3, 3, 3, 3,
        0, 0, 2, 2, 3, 1, 0, 2, 3, 0, 3, 0, 3, 1, 3, 1], device='cuda:0')
tensor([0, 2, 3, 1, 1, 0, 0, 2, 2, 0, 3, 1, 2, 0, 2, 0, 1, 1, 2, 0, 1, 0, 2, 1,
        1, 3, 2, 3, 2, 0, 2, 3, 2, 1, 2, 2, 1,

Epoch 0:  12%|█▏        | 44/375 [00:01<00:09, 33.50it/s]

tensor([2, 3, 2, 3, 0, 2, 1, 3, 2, 3, 1, 1, 3, 1, 2, 1, 2, 3, 1, 2, 3, 0, 2, 1,
        2, 3, 0, 1, 3, 3, 1, 3, 3, 2, 2, 0, 1, 0, 2, 3, 0, 2, 1, 3, 1, 3, 0, 2,
        2, 3, 3, 1, 0, 0, 2, 3, 0, 1, 1, 1, 2, 1, 1, 0, 2, 1, 0, 0, 2, 0, 0, 1,
        2, 2, 3, 1, 2, 3, 3, 0, 3, 2, 0, 1, 3, 3, 0, 1, 0, 3, 3, 2, 3, 1, 1, 1,
        0, 0, 2, 0, 0, 0, 0, 3, 1, 1, 3, 1, 1, 2, 3, 1, 2, 1, 0, 1, 3, 2, 2, 1,
        0, 0, 1, 2, 0, 3, 3, 2, 1, 0, 1, 2, 1, 2, 3, 2, 1, 3, 1, 2, 0, 3, 2, 0,
        1, 0, 3, 2, 3, 1, 1, 0, 0, 0, 3, 0, 2, 0, 2, 0, 3, 0, 3, 2, 3, 3, 1, 0,
        3, 0, 3, 1, 3, 1, 2, 0, 1, 1, 2, 0, 3, 1, 3, 1, 1, 1, 1, 1, 1, 1, 0, 1,
        2, 0, 0, 3, 0, 2, 2, 1, 1, 2, 2, 0, 0, 0, 1, 1, 1, 2, 2, 0, 2, 0, 2, 2,
        1, 0, 0, 2, 1, 0, 1, 0, 0, 0, 0, 1, 0, 1, 1, 0, 1, 0, 2, 3, 3, 3, 3, 2,
        1, 0, 3, 3, 2, 1, 0, 1, 3, 2, 1, 0, 0, 0, 0, 0], device='cuda:0')
tensor([2, 3, 2, 3, 0, 2, 1, 3, 2, 3, 1, 1, 3, 1, 2, 1, 2, 3, 1, 2, 3, 0, 2, 1,
        2, 3, 0, 1, 3, 3, 1, 3, 3, 2, 2, 0, 0,

Epoch 0:  14%|█▍        | 52/375 [00:01<00:10, 32.12it/s]

tensor([0, 2, 0, 0, 2, 1, 3, 1, 0, 0, 2, 2, 2, 2, 3, 0, 1, 3, 2, 3, 2, 2, 2, 1,
        3, 2, 2, 2, 0, 1, 2, 0, 1, 2, 2, 0, 0, 0, 0, 1, 0, 2, 3, 3, 2, 3, 2, 2,
        0, 2, 3, 3, 2, 0, 2, 2, 2, 1, 0, 2, 2, 0, 1, 1, 2, 0, 0, 1, 2, 3, 1, 1,
        1, 2, 3, 1, 0, 0, 3, 1, 3, 2, 1, 2, 0, 0, 0, 1, 2, 1, 3, 1, 1, 1, 2, 2,
        1, 3, 2, 3, 3, 2, 1, 3, 2, 1, 1, 3, 2, 0, 3, 3, 1, 3, 3, 3, 0, 1, 0, 3,
        2, 3, 3, 2, 0, 1, 1, 1, 2, 0, 2, 3, 0, 3, 1, 3, 0, 3, 3, 1, 3, 2, 1, 1,
        0, 3, 0, 0, 2, 3, 1, 2, 0, 2, 1, 1, 1, 1, 1, 1, 3, 0, 2, 2, 1, 3, 2, 0,
        2, 1, 2, 3, 0, 3, 2, 3, 0, 0, 0, 3, 3, 2, 0, 3, 2, 2, 0, 2, 2, 2, 3, 3,
        1, 1, 0, 0, 0, 3, 3, 3, 3, 2, 2, 1, 3, 1, 3, 2, 0, 2, 0, 2, 3, 0, 2, 3,
        2, 0, 0, 0, 0, 0, 0, 0, 3, 3, 0, 2, 2, 1, 1, 3, 0, 2, 3, 1, 0, 2, 0, 2,
        3, 3, 2, 0, 0, 0, 0, 0, 1, 0, 0, 2, 2, 3, 2, 3], device='cuda:0')
tensor([0, 0, 3, 0, 2, 1, 3, 1, 0, 3, 2, 2, 2, 2, 3, 0, 1, 3, 2, 3, 2, 2, 2, 1,
        0, 2, 2, 2, 0, 1, 2, 0, 1, 2, 2, 0, 0,

Epoch 0:  15%|█▍        | 56/375 [00:01<00:09, 32.46it/s]

tensor([0, 1, 3, 3, 0, 1, 0, 1, 0, 0, 3, 2, 0, 1, 0, 1, 3, 0, 3, 0, 3, 2, 3, 1,
        1, 3, 2, 3, 1, 0, 0, 2, 1, 0, 1, 2, 3, 2, 1, 0, 2, 1, 2, 2, 2, 1, 3, 3,
        3, 2, 1, 3, 2, 1, 3, 0, 2, 2, 3, 3, 0, 0, 0, 3, 1, 1, 3, 3, 0, 0, 1, 3,
        3, 3, 0, 0, 2, 3, 2, 1, 3, 3, 3, 0, 2, 1, 3, 2, 0, 3, 3, 2, 3, 1, 3, 1,
        0, 2, 2, 3, 1, 2, 2, 1, 3, 1, 3, 0, 0, 3, 2, 1, 3, 1, 0, 1, 1, 0, 1, 0,
        3, 0, 1, 3, 2, 0, 1, 3, 3, 0, 0, 1, 3, 3, 2, 3, 3, 3, 1, 3, 0, 2, 1, 2,
        0, 2, 1, 3, 2, 2, 3, 0, 0, 2, 0, 0, 3, 3, 2, 2, 2, 2, 1, 3, 3, 1, 1, 1,
        2, 3, 1, 1, 2, 1, 0, 3, 2, 2, 3, 1, 0, 3, 1, 2, 2, 2, 2, 1, 1, 0, 2, 3,
        2, 3, 2, 0, 2, 3, 0, 1, 2, 3, 3, 3, 0, 0, 0, 3, 1, 1, 1, 0, 1, 2, 0, 1,
        1, 2, 3, 3, 1, 1, 0, 0, 3, 0, 3, 1, 2, 2, 2, 0, 0, 0, 3, 1, 0, 1, 3, 1,
        2, 3, 2, 3, 3, 1, 1, 3, 0, 2, 0, 2, 0, 2, 3, 0], device='cuda:0')
tensor([0, 1, 3, 3, 0, 1, 0, 1, 0, 0, 3, 2, 0, 1, 0, 1, 3, 0, 3, 0, 3, 2, 3, 1,
        1, 3, 2, 3, 1, 0, 0, 2, 1, 0, 1, 2, 3,

Epoch 0:  17%|█▋        | 64/375 [00:01<00:09, 32.77it/s]

tensor([1, 3, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 1, 2, 3, 1, 2, 0, 1, 0, 1,
        1, 1, 0, 3, 3, 3, 0, 0, 1, 3, 0, 1, 3, 2, 0, 1, 3, 2, 0, 0, 3, 2, 3, 1,
        2, 3, 3, 1, 3, 0, 1, 0, 0, 2, 2, 0, 0, 1, 2, 1, 2, 2, 3, 2, 0, 2, 1, 2,
        1, 1, 2, 0, 0, 3, 2, 0, 1, 1, 3, 3, 2, 2, 0, 1, 3, 2, 1, 2, 0, 1, 2, 1,
        1, 1, 0, 0, 2, 3, 3, 0, 1, 1, 1, 0, 0, 3, 3, 1, 2, 3, 1, 0, 1, 2, 3, 0,
        0, 2, 2, 3, 3, 2, 3, 0, 3, 0, 2, 3, 2, 0, 3, 2, 0, 1, 3, 2, 0, 1, 0, 3,
        3, 0, 0, 3, 1, 0, 2, 0, 3, 3, 0, 1, 3, 2, 0, 2, 3, 1, 3, 0, 1, 3, 1, 3,
        0, 3, 0, 3, 3, 3, 2, 3, 3, 3, 3, 2, 0, 3, 0, 0, 3, 0, 1, 2, 2, 0, 1, 1,
        3, 2, 1, 1, 3, 3, 0, 0, 2, 2, 3, 0, 0, 0, 3, 3, 0, 1, 0, 2, 2, 3, 3, 3,
        1, 2, 0, 2, 3, 2, 2, 2, 1, 0, 3, 1, 3, 2, 1, 2, 1, 3, 0, 1, 0, 3, 0, 3,
        1, 1, 1, 3, 0, 3, 2, 3, 2, 1, 2, 1, 0, 3, 0, 3], device='cuda:0')
tensor([1, 3, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 1, 2, 3, 1, 2, 0, 1, 0, 1,
        1, 1, 3, 0, 3, 3, 0, 0, 1, 3, 0, 1, 3,

Epoch 0:  18%|█▊        | 68/375 [00:02<00:09, 33.47it/s]


tensor([3, 3, 0, 2, 0, 1, 1, 2, 1, 3, 0, 0, 2, 1, 3, 2, 1, 1, 1, 3, 2, 0, 2, 1,
        1, 1, 3, 2, 2, 3, 2, 2, 2, 0, 2, 1, 0, 1, 1, 3, 0, 1, 2, 2, 0, 0, 1, 2,
        0, 1, 2, 2, 0, 1, 0, 3, 0, 2, 0, 3, 3, 0, 3, 1, 2, 1, 0, 3, 0, 3, 3, 0,
        3, 2, 2, 1, 1, 0, 2, 0, 3, 2, 0, 1, 3, 1, 1, 1, 2, 3, 3, 0, 3, 0, 0, 1,
        0, 1, 1, 2, 1, 2, 0, 2, 1, 1, 1, 3, 0, 1, 3, 2, 2, 2, 0, 0, 3, 1, 1, 2,
        2, 3, 2, 1, 2, 3, 0, 3, 1, 1, 1, 1, 0, 1, 2, 3, 0, 1, 1, 2, 3, 2, 1, 0,
        1, 1, 2, 2, 0, 3, 0, 0, 3, 3, 2, 3, 2, 3, 2, 2, 2, 0, 2, 0, 2, 1, 1, 1,
        1, 0, 0, 2, 1, 1, 2, 0, 2, 3, 2, 2, 1, 3, 1, 1, 0, 0, 3, 2, 1, 1, 1, 3,
        2, 3, 3, 0, 1, 3, 3, 0, 1, 3, 1, 0, 2, 1, 0, 2, 3, 1, 3, 3, 3, 1, 2, 0,
        1, 3, 3, 2, 0, 0, 1, 1, 2, 2, 2, 3, 2, 3, 3, 3, 3, 3, 0, 1, 1, 3, 1, 2,
        1, 2, 2, 2, 3, 1, 0, 1, 2, 0, 3, 0, 2, 1, 2, 3], device='cuda:0')
tensor([3, 3, 0, 2, 0, 1, 1, 2, 1, 3, 0, 0, 0, 1, 3, 2, 1, 1, 1, 3, 2, 0, 2, 1,
        1, 1, 3, 3, 2, 3, 2, 2, 2, 0, 2, 1, 0,

KeyboardInterrupt: 