In [1]:
import torch
from torch.utils.tensorboard import SummaryWriter
from dataset import Articles
from utils import CLASSES, load_voc
import numpy as np
from model import TopicClassifier


In [2]:
# environment settings
SEED = 0          # fixed seed so shuffling / weight init are reproducible across runs
BATCH_SIZE = 50   # shared by the training and validation loaders

torch.manual_seed(SEED)
np.random.seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else 'cpu')
# Only query the GPU name when CUDA is actually available: the unconditional
# torch.cuda.current_device() call crashed on CPU-only machines despite the
# CPU fallback chosen just above.
if device.type == "cuda":
    print(torch.cuda.get_device_name(torch.cuda.current_device()))
else:
    print("CUDA not available, running on CPU")
writer = SummaryWriter("runs/topic_classifier.01")

# dataset
# GloVe variant (currently disabled):
#   data, emb = init()
#   Articles("train.csv", len(CLASSES), data, emb) / Articles("valid.csv", len(CLASSES), data, emb)
train_dataset = Articles("train.csv", len(CLASSES))
valid_dataset = Articles("valid.csv", len(CLASSES))

# dataloader
# NOTE(review): drop_last=True presumably keeps every batch at BATCH_SIZE because the
# model carries a hidden state (model.cell) across batches — confirm against model.py.
training = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE,
                                       shuffle=True, drop_last=True)
validation = torch.utils.data.DataLoader(valid_dataset, batch_size=BATCH_SIZE,
                                         shuffle=True, drop_last=True)


In [3]:
# model
LEARNING_RATE = 0.001

vocabulary = len(load_voc())
# .to(device) already places the model on the selected device; the former extra
# model.cuda() call was redundant on GPU and crashed on CPU-only machines.
model = TopicClassifier(vocabulary, len(CLASSES)).to(device)

# The dataset has one class per instance instead of a list of classes,
# so CrossEntropyLoss is used rather than BCELoss.
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

# Despite their names, these are the NUMBER OF BATCHES per epoch (len of a
# DataLoader), used to turn summed batch losses into per-batch means.
batch_size_1 = len(training)
batch_size_2 = len(validation)
train_loss = 0
validation_loss = 0

In [4]:
# Train for N_EPOCHS epochs. Each epoch: one training pass, an optional checkpoint,
# one validation pass, then the mean per-batch losses are printed and logged to TensorBoard.
N_EPOCHS = 200
CHECKPOINT_EVERY = 10

for epoch in range(N_EPOCHS):
    print("starting epoch:", epoch)

    # --- training pass ---------------------------------------------------------
    model.train()  # restore training mode after the previous epoch's eval()
    train_loss = 0.0

    for batch, (text, topic) in enumerate(training):

        if (batch % 20) == 0:
            print("batch:", batch)

        optimizer.zero_grad()
        y = model(text.to(device))

        loss = criterion(y.to(device), topic.squeeze().to(device))
        train_loss += loss.item()

        # NOTE(review): retain_graph=True is kept from the original; it looks
        # unnecessary since model.cell is detached after every step — confirm
        # against model.py before removing it (it keeps the graph alive longer).
        loss.backward(retain_graph=True)
        optimizer.step()
        del loss
        # The model keeps its state in model.cell between batches; detach it so the
        # next batch does not backpropagate into this batch's graph.
        model.cell[0].detach_()
        model.cell[1].detach_()

    if (epoch % CHECKPOINT_EVERY) == 0:
        # "optimizer" is an extra key so training can be resumed; loaders that only
        # read "state_dict" / "epoch" keep working.
        torch.save({"state_dict": model.state_dict(),
                    "optimizer": optimizer.state_dict(),
                    "epoch": epoch},
                   str(epoch) + "_topic.pth")

    # --- validation pass -------------------------------------------------------
    # eval() disables training-only layer behaviour; no_grad() avoids building an
    # autograd graph that is never backpropagated.
    model.eval()
    validation_loss = 0.0

    with torch.no_grad():
        for batch, (text, topic) in enumerate(validation):

            if (batch % 10) == 0:
                print("validation batch:", batch)

            y = model(text.to(device))

            loss = criterion(y.to(device), topic.squeeze().to(device))
            validation_loss += loss.item()

            del loss
            model.cell[0].detach_()
            model.cell[1].detach_()

    # Mean loss per batch; max(..., 1) guards against an empty loader.
    train_loss = train_loss / max(len(training), 1)
    validation_loss = validation_loss / max(len(validation), 1)

    # Report this epoch's losses (previously the print at the top of the loop
    # showed the *previous* epoch's values, and 0 for epoch 0).
    print("finished epoch:", epoch, "loss:", "training", train_loss, "validation", validation_loss)
    writer.add_scalars("loss", {"training": train_loss, "validation": validation_loss}, epoch)

In [5]:
# Save the final model weights, then close the TensorBoard writer.
# NOTE(review): the periodic checkpoints saved during training store the loop's
# epoch index, whereas this file stores 200 (the epoch count) — keep that in
# mind when resuming from one or the other.
final_checkpoint = {
    "state_dict": model.state_dict(),
    "epoch": 200,
}
torch.save(final_checkpoint, "topicClassifier.pth")

writer.close()