In [25]:
import os

import torch
import torch.nn.functional as F
from torch import nn

In [None]:
class Dataset(torch.utils.data.Dataset):
    """Next-token-prediction dataset over sequences of token ids.

    Each item is a pair ``(x, y)`` of ``long`` tensors of length ``max_len``:
    ``x`` is the sequence truncated to ``max_len`` and ``y`` is the same
    sequence shifted one step left, so ``y[t]`` is the token that follows
    ``x[t]``. ``eos_index`` is the target after the last real token; both
    tensors are right-padded with ``pad_index``.

    Args:
        data: indexable collection of token-id sequences (lists/tuples of ints).
        max_len: length every returned tensor is truncated/padded to.
        pad_index: token id used for padding.
        eos_index: token id used as the target after the final token.
    """

    def __init__(self, data, max_len, pad_index, eos_index):
        super().__init__()

        self.data = data
        self.max_len = max_len
        self.pad_index = pad_index
        self.eos_index = eos_index

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        tokens = list(self.data[index])

        # Input: the original sequence, truncated.
        x = tokens[:self.max_len]
        # Target: the shifted sequence. EOS is only the target once the real
        # sequence has ended; for a truncated sequence the target after x[-1]
        # is the next real token, not EOS.
        y = (tokens + [self.eos_index])[1:len(x) + 1]

        assert len(x) == len(y)

        pads = [self.pad_index] * (self.max_len - len(x))

        x = torch.tensor(x + pads, dtype=torch.long)
        y = torch.tensor(y + pads, dtype=torch.long)

        return x, y

In [27]:
class Attention(nn.Module):
    """Single-head scaled dot-product self-attention.

    Args:
        dim: feature size of the input; queries, keys and values are all
            projected to this same size.
        causal: if True, position ``t`` may only attend to positions
            ``<= t`` (needed for autoregressive decoding).
    """

    def __init__(self, dim, causal=False):
        super().__init__()

        self.dim = dim
        self.causal = causal
        self.query_w = nn.Linear(self.dim, self.dim)
        self.key_w = nn.Linear(self.dim, self.dim)
        self.value_w = nn.Linear(self.dim, self.dim)

    def forward(self, batch):
        """Apply self-attention.

        Args:
            batch: float tensor of shape ``(..., seq_len, dim)``,
                e.g. ``(batch, seq_len, dim)``.

        Returns:
            Tensor of the same shape: for each position, the attention-weighted
            sum of the value vectors.
        """
        queries = self.query_w(batch)
        keys = self.key_w(batch)
        values = self.value_w(batch)

        # (..., seq, seq) scores, scaled by sqrt(dim) so the softmax does not
        # saturate as dim grows. self.dim is a plain int, hence ** 0.5.
        scores = (queries @ keys.transpose(-2, -1)) / self.dim ** 0.5

        if self.causal:
            seq_len = scores.size(-1)
            future = torch.ones(
                seq_len, seq_len, dtype=torch.bool, device=scores.device
            ).triu(1)
            scores = scores.masked_fill(future, float("-inf"))

        # Normalize over the key axis so each query's weights sum to 1.
        attention_matrix = F.softmax(scores, dim=-1)

        # weights (..., seq, seq) @ values (..., seq, dim) -> (..., seq, dim)
        weighted_vectors = attention_matrix @ values

        return weighted_vectors

In [28]:
class DecoderBlock(nn.Module):
    """One post-norm Transformer decoder layer: causal self-attention + MLP.

    Each sub-layer is wrapped in a residual connection followed by LayerNorm.

    Args:
        dim: model (embedding) size of the incoming tensor.
        hidden_dim: width of the position-wise feed-forward layer;
            defaults to ``4 * dim``.
    """

    def __init__(self, dim, hidden_dim=None):
        super().__init__()
        hidden_dim = 4 * dim if hidden_dim is None else hidden_dim

        # Causal so that position t cannot see the tokens it has to predict.
        self.attention = Attention(dim, causal=True)
        self.attention_norm = nn.LayerNorm(dim)
        self.feedforward = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, dim),
        )
        self.feedforward_norm = nn.LayerNorm(dim)

    def forward(self, batch):
        """Apply the block.

        Args:
            batch: float tensor whose last dimension is ``dim``,
                e.g. ``(batch, seq_len, dim)``.

        Returns:
            Tensor of the same shape.
        """
        # Out-of-place adds: ``batch += ...`` would mutate the caller's tensor
        # and can break autograd when that tensor is needed for backward.
        batch = self.attention_norm(batch + self.attention(batch))
        batch = self.feedforward_norm(batch + self.feedforward(batch))

        return batch

In [None]:
class ChataboxModel(nn.Module):
    """Decoder-only Transformer language model.

    Token ids -> token + learned positional embeddings -> ``n_layers`` stacked
    DecoderBlocks -> linear projection to vocabulary logits.

    Args:
        vocab_size: number of distinct token ids.
        embedding_dim: model width.
        hidden_dim: feed-forward width inside each DecoderBlock.
        n_layers: number of stacked DecoderBlocks.
        max_len: longest sequence the positional embedding supports; must be
            at least the ``max_len`` used to build the dataset.
    """

    def __init__(self, vocab_size, embedding_dim, hidden_dim, n_layers, max_len=512):
        super().__init__()

        self.max_len = max_len
        self.emb_layer = nn.Embedding(vocab_size, embedding_dim)
        # Self-attention is permutation-equivariant, so token order must be
        # injected explicitly.
        self.pos_layer = nn.Embedding(max_len, embedding_dim)
        self.decoders = nn.Sequential(
            *(DecoderBlock(embedding_dim, hidden_dim) for _ in range(n_layers))
        )
        self.head = nn.Linear(embedding_dim, vocab_size)

    def forward(self, batch):
        """Compute next-token logits.

        Args:
            batch: long tensor of token ids with shape ``(batch, seq_len)``.

        Returns:
            Logits of shape ``(batch, seq_len, vocab_size)``.

        Raises:
            ValueError: if ``seq_len`` exceeds ``max_len``.
        """
        seq_len = batch.size(1)
        if seq_len > self.max_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds max_len={self.max_len}"
            )

        positions = torch.arange(seq_len, device=batch.device)
        hidden = self.emb_layer(batch) + self.pos_layer(positions)
        hidden = self.decoders(hidden)

        return self.head(hidden)

In [None]:
def predict(model, batch, multigpu_mode=False, device="cpu", inference=False):
    """Run ``model`` on one batch from a data loader.

    Args:
        model: callable mapping an input tensor to predictions.
        batch: either an ``(x, y)`` pair (as yielded for the ``Dataset`` above)
            or a bare input tensor ``x``.
        multigpu_mode: not used by this function. NOTE(review): presumably
            meant to toggle data-parallel execution; kept for call compatibility.
        device: device the tensors are moved to before the forward pass.
        inference: if True, return only the predictions.

    Returns:
        ``(predictions, ys)`` when ``inference`` is False, so callers can
        compute a loss; only ``predictions`` when ``inference`` is True.

    Raises:
        ValueError: if targets are needed (``inference=False``) but the batch
            has none.
    """
    if isinstance(batch, (list, tuple)):
        xs = batch[0]
        ys = batch[1] if len(batch) > 1 else None
    else:
        xs, ys = batch, None

    # Only the inputs go through the model; targets are returned for the loss.
    predictions = model(xs.to(device))

    if inference:
        return predictions

    if ys is None:
        raise ValueError("predict() needs (x, y) batches unless inference=True")

    return predictions, ys.to(device)

In [None]:
def train(
    model,
    iterator,
    optimizer,
    criterion,
    print_every=10,
    epoch=0,
    device="cpu",
):
    """Run one training epoch and return the mean per-batch loss.

    Args:
        model: the network being trained.
        iterator: data loader yielding ``(x, y)`` batches.
        optimizer: optimizer stepping ``model``'s parameters.
        criterion: loss function. NOTE(review): assumed to be a token-level
            classification loss such as ``nn.CrossEntropyLoss`` (ideally with
            ``ignore_index=pad_index``) taking ``(N, vocab)`` logits and
            ``(N,)`` integer targets — confirm where it is constructed.
        print_every: log the batch loss every this many steps.
        epoch: epoch number, used only in log messages.
        device: device batches are moved to.

    Returns:
        Mean batch loss over the epoch (0.0 if ``iterator`` is empty).
    """
    print(f"epoch {epoch}")

    epoch_loss = 0.0
    n_batches = 0

    model.train()

    for i, batch in enumerate(iterator):
        optimizer.zero_grad()

        # Keywords: predict's third positional parameter is multigpu_mode.
        predictions, ys = predict(model, batch, multigpu_mode=False, device=device)

        # Flatten (batch, seq, vocab) logits and (batch, seq) targets. Targets
        # must stay integer class ids; casting them to float breaks the loss.
        loss = criterion(
            predictions.reshape(-1, predictions.size(-1)).float(),
            ys.reshape(-1),
        )
        loss.backward()

        optimizer.step()

        batch_loss = loss.item()
        epoch_loss += batch_loss
        n_batches += 1

        if (i + 1) % print_every == 0:
            print(f"step {i + 1} from {len(iterator)} at epoch {epoch}")
            print(f"Loss: {batch_loss}")

    return epoch_loss / max(n_batches, 1)


In [None]:
def evaluate(model, iterator, criterion, epoch=0, device="cpu", save_checkpoints=True, timestamp=None):
    """Compute the mean validation loss for one epoch, optionally checkpointing.

    Args:
        model: the network being evaluated.
        iterator: data loader yielding ``(x, y)`` batches.
        criterion: loss function. NOTE(review): assumed to take ``(N, vocab)``
            logits and ``(N,)`` integer targets (e.g. ``nn.CrossEntropyLoss``)
            — confirm where it is constructed.
        epoch: epoch number, used in logs and the checkpoint file name.
        device: device batches are moved to.
        save_checkpoints: if True, save ``model.state_dict()`` to
            ``logs/checkpoint/{timestamp}_epoch_{epoch}.pt``.
        timestamp: prefix of the checkpoint file name.

    Returns:
        Mean batch loss (0.0 if ``iterator`` is empty).
    """
    print(f"epoch {epoch} evaluation")

    epoch_loss = 0.0
    n_batches = 0

    model.eval()

    with torch.no_grad():
        # NOTE(review): tqdm is not imported in the visible cells; it must come
        # from an earlier cell (e.g. ``from tqdm.auto import tqdm``).
        for batch in tqdm(iterator):

            predictions, ys = predict(model, batch, multigpu_mode=False, device=device)

            # Same flattening as training; targets stay integer class ids.
            loss = criterion(
                predictions.reshape(-1, predictions.size(-1)).float(),
                ys.reshape(-1),
            )

            epoch_loss += loss.item()
            n_batches += 1

    overall_loss = epoch_loss / max(n_batches, 1)

    if save_checkpoints:
        folder = os.path.expanduser('logs/checkpoint/')
        # torch.save does not create missing directories.
        os.makedirs(folder, exist_ok=True)
        path = os.path.join(folder, f'{timestamp}_epoch_{epoch}.pt')
        torch.save(model.state_dict(), path)

    print(f"epoch loss {overall_loss}")
    print(
        "========================================================================================================"
    )

    return overall_loss

In [None]:
def inference(model, iterator, device='cpu'):
    """Run ``model`` over every batch and return all predictions.

    Args:
        model: the trained network.
        iterator: data loader yielding ``(x, y)`` pairs or bare input tensors.
        device: device each batch is moved to for the forward pass.

    Returns:
        Predictions for all batches, concatenated along dim 0 and moved to CPU.

    Raises:
        ValueError: if ``iterator`` yields no batches.
    """
    model.eval()

    batch_predictions = []
    with torch.no_grad():
        # NOTE(review): tqdm must be imported in an earlier cell.
        for batch in tqdm(iterator):
            predictions = predict(
                model, batch, multigpu_mode=False, device=device, inference=True
            )
            # Keep results on CPU so a long iterator does not hold every
            # batch's output in accelerator memory.
            batch_predictions.append(predictions.cpu())

    if not batch_predictions:
        raise ValueError("inference() received an empty iterator")

    return torch.cat(batch_predictions, dim=0)

## Цикл обучения

In [None]:
%%time
loss = []
loss_eval = []
scores = []

print(timestamp)
print(f'Start training model {str(model)}')
for i in range(epochs):
    loss.append(train(model, training_generator, optimizer, criterion, epoch=i, device=device))
    loss_eval.append(evaluate(model, valid_generator, criterion, epoch=i, device=device, save_checkpoints=True, timestamp=timestamp))

In [None]:
# Train vs. validation loss per epoch; the y-label names the criterion that
# was actually used instead of a placeholder.
fig, ax = plt.subplots()
ax.plot(loss, color='red', label='train')
ax.plot(loss_eval, color='blue', label='eval')
ax.set(
    xlabel='epoch',
    ylabel=f'Loss ({type(criterion).__name__})',
    title='ChataboxModel',
)
ax.legend()
plt.show()