In [1]:
# Enable IPython's autoreload extension in mode 2, so imported modules are
# re-imported before each cell execution and edits to project code are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
# Make the parent of the current working directory importable, so that
# `from src... import ...` works from inside this notebook.
import os
import sys

# os.pardir is the platform's ".." token; abspath resolves it against the CWD.
module_path = os.path.abspath(os.pardir)
if module_path not in sys.path:
    sys.path.append(module_path)

In [3]:
from collections import defaultdict
from pathlib import Path
import random
import time
from typing import Any, Dict, List

from matplotlib import pyplot as plt
import numpy as np
import pandas as pd
from pytorch_metric_learning import losses
from pytorch_metric_learning.distances import CosineSimilarity
import torch
import torch.nn as nn
import tqdm
import wandb

from src.data.nn_preprocessing import clean_company_name_string
from src.data.lstm_dataloader import CompanyNameDataLoader
from src.models.lstm import LSTMNetwork

In [4]:
# Load the train and validation CSVs prepared for metric learning
# (paths are relative to the notebook's working directory).
train_data, valid_data = (
    pd.read_csv(csv_path)
    for csv_path in (
        '../data/processed/train_companies_for_metric_learning.csv',
        '../data/processed/valid_companies_for_metric_learning.csv',
    )
)


# Цикл обучения

In [5]:
# Contrastive loss computed on cosine similarity instead of a distance.
# NOTE(review): pos_margin=1 / neg_margin=0 appear to be the margins the
# pytorch-metric-learning docs suggest for CosineSimilarity — confirm.
contrastive_loss = losses.ContrastiveLoss(
    distance=CosineSimilarity(),
    pos_margin=1,
    neg_margin=0,
)

# Directory where model checkpoints are written.
output_path = Path('../models')

# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device

device(type='cuda')

In [7]:
# --- Experiment hyperparameters ---------------------------------------------
seed = 42
lr = 1e-3
emb_dim = 80
hidden_size = 80
num_layers = 3
dropout = 0.0
# NOTE(review): batch_size is only logged to wandb; the dataloaders below are
# built without it, so the effective batch size is whatever
# CompanyNameDataLoader uses by default — confirm the two agree.
batch_size = 16
proj_size = 0
num_epochs = 20
exp_name = f'LSTM_train_{emb_dim}_{hidden_size}_{num_layers}_{dropout}_{proj_size}_stop_words'

# Seed the RNGs of every library in play so a Restart & Run All is repeatable.
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

# Everything that defines the run is logged, so runs can be compared in wandb.
config = dict(
    learning_rate=lr,
    architecture="LSTM",
    batch_size=batch_size,
    hidden_size=hidden_size,
    num_layers=num_layers,
    emb_dim=emb_dim,
    dropout=dropout,
    proj_size=proj_size,
    num_epochs=num_epochs,
    seed=seed,
)


def _sort_batch_by_length(batch):
    """Return ``(inputs, labels)`` of a batch, ordered by input length, longest first.

    ``batch[0]`` holds the inputs and ``batch[1]`` the matching labels; the
    pairing between them is preserved by the sort.
    NOTE(review): the descending-length order is presumably required by the
    sequence packing inside LSTMNetwork — confirm.
    """
    ordered = sorted(zip(batch[0], batch[1]), key=lambda pair: len(pair[0]), reverse=True)
    inputs, labels = zip(*ordered)
    return inputs, labels


def _run_epoch(model, dataloader, optimizer=None):
    """Run one full pass over ``dataloader`` and return the mean batch loss.

    When ``optimizer`` is given the model is put in training mode and updated
    after every batch; otherwise the model is put in eval mode and no
    gradients are computed.
    """
    training = optimizer is not None
    # Switch dropout (and any other mode-dependent layers) to the right mode.
    model.train(training)
    batch_losses = []
    for batch in tqdm.tqdm(dataloader):
        x, y = _sort_batch_by_length(batch)
        with torch.set_grad_enabled(training):
            pred = model(x)
            loss = contrastive_loss(pred, torch.Tensor(y).to(device))
        if training:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        batch_losses.append(loss.item())
    return np.mean(batch_losses)


wandb.init(
    project='dl_cnp_23',
    name=exp_name,
    config=config,
)

# try/finally guarantees the wandb run is closed even if training fails or is
# interrupted from the notebook.
try:
    train_dataloader = CompanyNameDataLoader(train_data, shuffle=True, preprocessing=clean_company_name_string)
    valid_dataloader = CompanyNameDataLoader(valid_data, shuffle=False, preprocessing=clean_company_name_string)

    net = LSTMNetwork(emb_dim=emb_dim,
                      hidden_size=hidden_size,
                      num_layers=num_layers,
                      dropout=dropout,
                      proj_size=proj_size).to(device)
    # Use the configured learning rate so the value logged to wandb is the one
    # actually applied.
    optimizer = torch.optim.Adam(net.parameters(), lr=lr)
    best_valid_loss = np.inf

    # The checkpoint directory may not exist on a fresh checkout.
    output_path.mkdir(parents=True, exist_ok=True)

    for epoch in range(num_epochs):
        print(f'Epoch {epoch}')
        train_loss = _run_epoch(net, train_dataloader, optimizer)
        # NOTE(review): the short sleeps presumably let tqdm's progress bar
        # flush before the next print, keeping the output readable — confirm.
        time.sleep(0.5)
        valid_loss = _run_epoch(net, valid_dataloader)
        print(f'train loss: {train_loss}, valid_loss: {valid_loss}')
        wandb.log({'train_loss': train_loss, 'valid_loss': valid_loss}, step=epoch)
        # Keep only the weights of the epoch with the lowest validation loss.
        if valid_loss < best_valid_loss:
            best_valid_loss = valid_loss
            torch.save(net.state_dict(), output_path / f'{exp_name}_best.pth')
        time.sleep(0.5)
finally:
    wandb.finish()

[34m[1mwandb[0m: Currently logged in as: [33mmgurevich[0m. Use [1m`wandb login --relogin`[0m to force relogin


Epoch 0


1096it [00:11, 99.13it/s]
107it [00:00, 300.56it/s]


train loss: 0.43472485411504325, valid_loss: 0.35685104075993335
Epoch 1


1096it [00:11, 93.65it/s] 
107it [00:00, 293.35it/s]


train loss: 0.2288205150827548, valid_loss: 0.27868598434969644
Epoch 2


1096it [00:11, 96.15it/s] 
107it [00:00, 299.51it/s]


train loss: 0.17273355318049827, valid_loss: 0.22877967143518346
Epoch 3


1096it [00:11, 95.37it/s]
107it [00:00, 299.41it/s]


train loss: 0.14117190577973524, valid_loss: 0.24864941206978303
Epoch 4


1096it [00:11, 94.33it/s]
107it [00:00, 299.64it/s]


train loss: 0.12051873148782655, valid_loss: 0.233514806464593
Epoch 5


1096it [00:11, 95.73it/s]
107it [00:00, 321.93it/s]


train loss: 0.10450605075624958, valid_loss: 0.22053107100530206
Epoch 6


1096it [00:11, 97.30it/s] 
107it [00:00, 299.97it/s]


train loss: 0.09762612885810829, valid_loss: 0.20949080776179504
Epoch 7


1096it [00:11, 96.85it/s]
107it [00:00, 308.17it/s]


train loss: 0.09511689951258412, valid_loss: 0.24701234931996632
Epoch 8


1096it [00:11, 91.85it/s]
107it [00:00, 307.25it/s]


train loss: 0.09049658676285219, valid_loss: 0.21381217187757942
Epoch 9


1096it [00:11, 93.49it/s]
107it [00:00, 321.15it/s]


train loss: 0.08743265605965854, valid_loss: 0.19336907130396255
Epoch 10


1096it [00:11, 96.55it/s]
107it [00:00, 305.88it/s]


train loss: 0.0921248224197762, valid_loss: 0.1843334998125994
Epoch 11


1096it [00:11, 96.16it/s] 
107it [00:00, 303.26it/s]


train loss: 0.07823188918754016, valid_loss: 0.22670681636658646
Epoch 12


1096it [00:11, 95.95it/s]
107it [00:00, 321.32it/s]


train loss: 0.07956297962477184, valid_loss: 0.2100659932497297
Epoch 13


1096it [00:11, 97.12it/s]
107it [00:00, 298.76it/s]


train loss: 0.08889294068083566, valid_loss: 0.22763970418437166
Epoch 14


1096it [00:11, 94.07it/s]
107it [00:00, 299.01it/s]


train loss: 0.08123524709158352, valid_loss: 0.21077068090421436
Epoch 15


1096it [00:11, 96.19it/s] 
107it [00:00, 300.38it/s]


train loss: 0.07701680902476647, valid_loss: 0.21848768125251197
Epoch 16


1096it [00:11, 93.14it/s]
107it [00:00, 320.63it/s]


train loss: 0.07434990599649344, valid_loss: 0.1891397076537049
Epoch 17


1096it [00:11, 94.24it/s]
107it [00:00, 301.62it/s]


train loss: 0.0686096438345958, valid_loss: 0.17697967766462086
Epoch 18


1096it [00:11, 96.15it/s]
107it [00:00, 298.85it/s]


train loss: 0.0678343432727871, valid_loss: 0.20008771242922016
Epoch 19


1096it [00:11, 95.51it/s]
107it [00:00, 308.11it/s]


train loss: 0.06578497065535625, valid_loss: 0.22071510028294244


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
train_loss,█▄▃▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁
valid_loss,█▅▃▄▃▃▂▄▂▂▁▃▂▃▂▃▁▁▂▃

0,1
train_loss,0.06578
valid_loss,0.22072
