In [None]:
!pip install torchmetrics -q

In [None]:
!pip install tqdm -q

In [None]:
!pip install datasets -q

In [None]:
!pip install nltk -q

In [None]:
!pip install gensim -q

In [None]:
import nltk
import torch
import sklearn
import datasets
import ipywidgets
import numpy as np
import torch.nn.functional as f
import gensim.downloader as api
import matplotlib.pyplot as plt
from torch import nn
from tqdm import tqdm, trange
from torchmetrics import Accuracy
from ipywidgets import FloatProgress
from torch.utils.data import DataLoader, TensorDataset

In [None]:
def encode(word):
    """Return the vocabulary index of ``word``.

    The lookup goes through the notebook-level ``word2idx`` mapping; a token
    that is not in the mapping gets the index stored under the ``'unk'`` key.
    """
    key = word if word in word2idx else 'unk'
    return word2idx[key]

def collate_fn(batch):
    """Collate variable-length examples into one zero-padded batch.

    Args:
        batch: non-empty sequence of rows; each row provides ``row['features']``
            (a 1-D sequence or tensor of token indices) and ``row['label']``
            (a class index, either an int or a 0-d tensor).

    Returns:
        dict with
        ``'features'``: ``torch.long`` tensor of shape ``(len(batch), max_len)``
        in which shorter rows are right-padded with index 0, and
        ``'labels'``: ``torch.long`` tensor of shape ``(len(batch),)``.

    Raises:
        ValueError: if ``batch`` is empty (there is no maximum length to pad to).
    """
    rows = [torch.as_tensor(row['features'], dtype=torch.long) for row in batch]
    max_len = max(len(row) for row in rows)

    # Pre-fill with the padding index and copy only the real tokens. Building
    # each row with torch.cat((indices, torch.zeros(k))) would mix int64 with
    # float32 and round-trip the indices through float32, which corrupts any
    # index above 2**24.
    input_embeds = torch.zeros((len(rows), max_len), dtype=torch.long)
    for idx, row in enumerate(rows):
        input_embeds[idx, :len(row)] = row

    labels = torch.as_tensor([int(row['label']) for row in batch], dtype=torch.long)

    return {'features': input_embeds, 'labels': labels}



In [None]:
def freeze_embeddings(model, req_grad=False):
    """Switch gradient tracking for all parameters of ``model.embeddings``.

    Meant for freezing the embedding layer during the first N training
    iterations, so that those early updates do not disturb its weights.

    Args:
        model: object exposing an ``embeddings`` attribute with ``parameters()``.
        req_grad: value assigned to ``requires_grad`` of every embedding
            parameter; ``False`` freezes them, ``True`` makes them trainable.
    """
    for param in model.embeddings.parameters():
        param.requires_grad = req_grad

In [None]:
def train_network(model, criterion, optim, metric, num_epochs, loaders, max_grad_norm=2, num_freeze_iter=1000):
    """Train ``model`` and print validation loss/accuracy after every epoch.

    Args:
        model: network exposing an ``embeddings`` submodule (required by
            ``freeze_embeddings``). Batches are moved to the device of its
            first parameter.
        criterion: loss callable invoked as ``criterion(pred, labels)``; must
            return a scalar tensor.
        optim: optimizer holding ``model``'s parameters.
        metric: callable invoked as ``metric(pred, labels)`` on every
            validation batch; must return a scalar.
        num_epochs: number of passes over ``loaders['train']``.
        loaders: mapping with ``'train'`` and ``'test'`` iterables yielding
            dicts with ``'features'`` and ``'labels'`` tensors.
        max_grad_norm: gradient-norm clipping threshold; a falsy value
            disables clipping.
        num_freeze_iter: number of initial optimisation steps, counted across
            epochs, during which the embeddings stay frozen.

    Returns:
        None. The per-epoch metrics are printed (mean over validation batches).
    """
    # Follow the model's own device rather than a notebook-level global.
    device = next(model.parameters()).device

    freeze_embeddings(model)  # embeddings start out frozen
    embeddings_frozen = True
    global_step = 0  # counted across epochs so the thaw happens exactly once

    for _ in tqdm(range(num_epochs)):
        model.train()

        for batch in loaders['train']:
            if embeddings_frozen and global_step >= num_freeze_iter:
                freeze_embeddings(model, True)
                embeddings_frozen = False

            # Use the optimizer that was passed in, not a global one.
            optim.zero_grad()
            input_embeds = batch['features'].to(device)
            labels = batch['labels'].to(device)
            pred = model(input_embeds)
            loss = criterion(pred, labels)

            loss.backward()

            if max_grad_norm:
                # In-place variant; `clip_grad_norm` (no underscore) is deprecated.
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)

            optim.step()
            global_step += 1

        valid_loss = 0.0
        valid_acc = 0.0
        num_batches = 0
        model.eval()

        with torch.no_grad():
            for batch in loaders['test']:
                input_embeds = batch['features'].to(device)
                labels = batch['labels'].to(device)
                pred = model(input_embeds)

                # Accumulate plain Python floats instead of device tensors.
                valid_loss += criterion(pred, labels).item()
                valid_acc += float(metric(pred, labels))
                num_batches += 1

        # An empty validation loader yields NaN instead of ZeroDivisionError.
        mean_loss = valid_loss / num_batches if num_batches else float('nan')
        mean_acc = valid_acc / num_batches if num_batches else float('nan')
        print(f'Valid Loss: {mean_loss}, Accuracy: {mean_acc}')

In [None]:
# Copy the pretrained vectors into the model's embedding table.
# (Before running this cell: bind `model` to the desired architecture and train it;
# then run this cell and train once more.)
# NOTE(review): `model` is not assigned in any visible cell — it must be set by hand,
# e.g. to one of model_cnn / model_rnn / model_gru.
with torch.no_grad():
    for word, idx in word2idx.items():
        if word in word2vec:
            # The model classes in this notebook call the layer `embeddings` (plural).
            model.embeddings.weight[idx] = torch.from_numpy(word2vec.get_vector(word))

In [None]:
# A single seed shared by NumPy and by PyTorch's CPU and CUDA generators.
SEED = 0xDEAD

torch.cuda.manual_seed_all(SEED)
torch.manual_seed(SEED)
np.random.seed(SEED)

In [None]:
# Use the first GPU when CUDA is usable, otherwise fall back to the CPU.
# `is_available` has to be *called*: the bare function object is always truthy.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

In [None]:
# Load the 'ag_news' dataset through the `datasets` library.
dataset = datasets.load_dataset('ag_news')

In [None]:
# Tokenizer and the maximum number of tokens kept per text.
tokenizer = nltk.WordPunctTokenizer()
max_length = 128

# Add a 'tokenized' column: the tokens of 'text', truncated to `max_length`.
# NOTE(review): tokens are not lower-cased here — confirm this matches the
# casing of the embedding vocabulary used for `word2idx`.
dataset = dataset.map(lambda x: {
    'tokenized': tokenizer.tokenize(x['text'])[:max_length]
})

In [None]:
# Load the 'glove-twitter-50' vectors through gensim's downloader API.
word2vec = api.load('glove-twitter-50')

In [None]:
# Number of batches in the training loader.
# NOTE(review): `loaders` is only created in a later cell, so on a fresh
# Restart & Run All this cell raises NameError.
len(loaders['train'])

In [None]:
# Vocabulary lookup: token -> its position in `word2vec.index_to_key`.
word2idx = {
    token: position
    for position, token in enumerate(word2vec.index_to_key)
}

In [None]:
# Add a 'features' column: the vocabulary index (via `encode`) of every token in 'tokenized'.
dataset = dataset.map(lambda x:{
    'features': [encode(word) for word in x['tokenized']]
})

In [None]:
# Drop the raw text and the token lists; 'features' was derived from them in the previous step.
dataset = dataset.remove_columns(['text', 'tokenized'])

In [None]:
# Have the dataset return its columns as torch tensors.
dataset.set_format(type='torch')

In [None]:
# One DataLoader per dataset split; only the 'train' split is shuffled.
loaders = {
    split: DataLoader(
        split_data,
        batch_size=32,
        shuffle=(split == 'train'),
        collate_fn=collate_fn,
    )
    for split, split_data in dataset.items()
}

# Сверточная нейросеть

In [None]:
class CNN_Model(nn.Module):
    """Text classifier: embedding lookup -> three strided Conv1d stages -> linear head.

    Args:
        embed_size: dimensionality of the token embeddings.
        hidden_size: number of channels produced by every convolution.
        num_classes: number of output logits.
    """

    @staticmethod
    def _conv_stage(in_channels, out_channels):
        """Return the layers of one Conv1d -> BatchNorm1d -> ReLU stage."""
        return [
            nn.Conv1d(in_channels, out_channels, kernel_size=3, padding=1, stride=2),
            nn.BatchNorm1d(out_channels),
            nn.ReLU(),
        ]

    def __init__(self, embed_size, hidden_size, num_classes=4):
        super().__init__()

        # One embedding row for every word of the notebook-level vocabulary.
        self.embeddings = nn.Embedding(len(word2idx), embed_size)

        layers = []
        for in_channels in (embed_size, hidden_size, hidden_size):
            layers.extend(self._conv_stage(in_channels, hidden_size))
        # Average over the sequence axis down to one position, then flatten it away.
        layers += [nn.AdaptiveAvgPool1d(1), nn.Flatten()]
        self.cnn = nn.Sequential(*layers)

        self.cls = nn.Sequential(nn.Linear(hidden_size, num_classes))

    def forward(self, x):
        """Return class logits for a batch of token-index sequences."""
        embedded = self.embeddings(x)
        # Swap the last two axes so the embedding features become Conv1d channels.
        channels_first = embedded.permute(0, 2, 1)
        return self.cls(self.cnn(channels_first))

## Инициализация

In [None]:
# CNN classifier and its training objects; the embedding width follows the pretrained vectors.
criterion = nn.CrossEntropyLoss()
metric_cnn = Accuracy('multiclass', num_classes=4).to(device)
model_cnn = CNN_Model(embed_size=word2vec.vector_size, hidden_size=50).to(device)
optimizer = torch.optim.Adam(model_cnn.parameters(), lr=1e-2)

## Обучение

# Классическая рекуррентная нейросеть

In [None]:
class RNN_block(nn.Module):
    """Simple recurrent cell unrolled over a whole sequence.

    For every time step ``t``: ``h_t = tanh(x_t @ W + h_{t-1} @ U + b_h)``.
    The module returns ``tanh(h_T @ V + b_x)`` computed from the last state.

    Args:
        embed_size: size of each input vector ``x_t``.
        hidden_size: size of the hidden state and of the output.
    """
    def __init__(self, embed_size, hidden_size):
        super().__init__()

        self.embed_size = embed_size
        self.hidden_size = hidden_size

        self.W = nn.Parameter(torch.rand(embed_size, hidden_size))    # input  -> hidden
        self.U = nn.Parameter(torch.rand(hidden_size, hidden_size))   # hidden -> hidden
        self.V = nn.Parameter(torch.rand(hidden_size, hidden_size))   # hidden -> output
        self.b_x = nn.Parameter(torch.rand(1, hidden_size))           # output bias
        self.b_h = nn.Parameter(torch.rand(1, hidden_size))           # recurrence bias

    def forward(self, x, hidden=None):
        """Run the recurrence over axis 1 of ``x`` and project the final state.

        Args:
            x: input tensor read one time step at a time as ``x[:, t]``,
                i.e. laid out as (batch, seq_len, embed_size).
            hidden: optional initial hidden state of shape
                (batch, hidden_size); zeros are used when it is omitted.

        Returns:
            Tensor of shape (batch, hidden_size).
        """
        # Only create the zero state when the caller did not supply one.
        # Overwriting `hidden` unconditionally would make the `is None` test
        # always false and the recurrence below would never see the input.
        if hidden is None:
            hidden = x.new_zeros((x.size(0), self.hidden_size))

        # Update the hidden state with every position of every sequence in the batch.
        for cur_idx in range(x.size(1)):
            hidden = torch.tanh(x[:, cur_idx] @ self.W + hidden @ self.U + self.b_h)

        return torch.tanh(hidden @ self.V + self.b_x)
            
            
         

In [None]:
class RNN_Model(nn.Module):
    """Sequence classifier: embedding lookup -> ``RNN_block`` -> linear head.

    Args:
        embed_size: dimensionality of the token embeddings.
        hidden_size: size of the recurrent block's output.
        num_classes: number of output logits.
    """

    def __init__(self, embed_size, hidden_size, num_classes=4):
        super().__init__()
        vocab_size = len(word2idx)  # notebook-level vocabulary
        self.embeddings = nn.Embedding(vocab_size, embed_size)
        self.rnn = RNN_block(embed_size, hidden_size)
        self.cls = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        """Return class logits for a batch of token-index sequences."""
        summary = self.rnn(self.embeddings(x))
        return self.cls(summary)
        
        

In [None]:
# RNN classifier and its training objects.
# Note: `criterion` and `optimizer` are notebook-level names and get rebound here.
model_rnn = RNN_Model(embed_size=word2vec.vector_size, hidden_size=50)
model_rnn = model_rnn.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model_rnn.parameters(), lr=1e-2)
metric_rnn = Accuracy('multiclass', num_classes=4).to(device)

In [None]:
# The notebook-level `optimizer` was last bound to model_rnn's parameters,
# so give the CNN a dedicated optimizer over its own parameters.
optimizer_cnn = torch.optim.Adam(model_cnn.parameters(), lr=1e-2)
train_network(model_cnn, criterion, optimizer_cnn, metric_cnn, 1, loaders)

In [None]:
# Ask PyTorch to release its cached, currently unused GPU memory.
torch.cuda.empty_cache()

In [None]:
# Shell out to `nvidia-smi` to inspect the current GPU state.
!nvidia-smi

In [None]:
# Evaluate with the RNN's own metric object; the name `metric` is only
# assigned further down, in the GRU setup cell.
train_network(model_rnn, criterion, optimizer, metric_rnn, 1, loaders)

# GRU (модификация RNN)

In [None]:
class GRU(nn.Module):
    """Hand-written GRU layer: consumes a whole sequence, returns the last hidden state.

    Per time step (``x_t`` is the current input, ``h`` the previous state):

        r = sigmoid(x_t @ w_rx + b_rx + h @ w_rh + b_rh)       reset gate
        z = sigmoid(x_t @ w_zx + b_zx + h @ w_zh + b_zh)       update gate
        n = tanh(x_t @ w_nx + b_nx + r * (h @ w_nh + b_nh))    candidate state
        h = (1 - z) * n + z * h

    Args:
        embed_size: size of each input vector ``x_t``.
        hidden_size: size of the hidden state.
    """

    def __init__(self, embed_size, hidden_size):
        super().__init__()

        self.embed_size = embed_size
        self.hidden_size = hidden_size

        # Every gate owns a hidden->hidden and an input->hidden projection,
        # each with its own bias, all drawn with torch.rand. Parameters are
        # registered in the order w_*h, b_*h, w_*x, b_*x for the gates r, z, n.
        for gate in ('r', 'z', 'n'):
            setattr(self, f'w_{gate}h', nn.Parameter(torch.rand(hidden_size, hidden_size)))
            setattr(self, f'b_{gate}h', nn.Parameter(torch.rand(1, hidden_size)))
            setattr(self, f'w_{gate}x', nn.Parameter(torch.rand(embed_size, hidden_size)))
            setattr(self, f'b_{gate}x', nn.Parameter(torch.rand(1, hidden_size)))

    def forward(self, x, hidden=None):
        """Run the recurrence over axis 1 of ``x`` and return the final hidden state.

        Args:
            x: input tensor consumed one slice at a time along axis 1.
            hidden: optional initial state; when omitted, zeros of shape
                ``(x.size(0), hidden_size)`` on ``x``'s device are used.
        """
        state = hidden
        if state is None:
            state = torch.zeros(x.size(0), self.hidden_size, device=x.device)

        for step in x.unbind(dim=1):
            reset = torch.sigmoid(step @ self.w_rx + self.b_rx + state @ self.w_rh + self.b_rh)
            update = torch.sigmoid(step @ self.w_zx + self.b_zx + state @ self.w_zh + self.b_zh)
            candidate = torch.tanh(step @ self.w_nx + self.b_nx + reset * (state @ self.w_nh + self.b_nh))
            state = (1 - update) * candidate + update * state

        return state
        

In [None]:
class GRU_Model(nn.Module):
    """Sequence classifier: embedding lookup -> ``GRU`` -> linear head.

    Args:
        embed_size: dimensionality of the token embeddings.
        hidden_size: size of the GRU hidden state.
        num_classes: number of output logits.
    """

    def __init__(self, embed_size, hidden_size, num_classes=4):
        super().__init__()
        vocab_size = len(word2idx)  # notebook-level vocabulary
        self.embeddings = nn.Embedding(vocab_size, embed_size)
        self.gru = GRU(embed_size, hidden_size)
        self.cls = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        """Return class logits computed from the GRU's final hidden state."""
        last_hidden = self.gru(self.embeddings(x))
        return self.cls(last_hidden)

In [None]:
# GRU classifier and its training objects.
# Note: `criterion` and `optimizer` are notebook-level names and get rebound here.
criterion = nn.CrossEntropyLoss()
metric = Accuracy('multiclass', num_classes=4).to(device)
model_gru = GRU_Model(word2vec.vector_size, 50).to(device)
optimizer = torch.optim.Adam(model_gru.parameters(), lr=1e-2)

In [None]:
# Train the GRU classifier for one epoch.
train_network(
    model_gru,
    criterion,
    optimizer,
    metric,
    num_epochs=1,
    loaders=loaders,
)

In [None]:
# Ask PyTorch to release its cached, currently unused GPU memory.
torch.cuda.empty_cache()

In [None]:
# Shell out to `nvidia-smi` to inspect the current GPU state.
!nvidia-smi