In [1]:
'''
This is a model of Text sentiment classification.
Li Teng
10.11.2021


The process consists of the following steps:

1.load config

2.build model

3.load data and clean it.

4.training and recording 

5.validation
'''

'\nThis is a model of Text sentiment classification.\nLi Teng\n10.11.2021\n\n\nThe process will be done in following step:\n\n1.load config\n\n2.build model\n\n3.load data and clean it.\n\n4.training and recording \n\n5.testing\n'

In [2]:
import torch
from config import DefaultConfig
from Text_Cnn import ConvNet, dynamical_padding
import os
from utils import load_flattened_documents, load_datasets
from datasets_preprocessing import clean_datasets, Movie_Classif_Dataset
from torch.utils.data import DataLoader
import torch.nn as nn
import wandb

# Step 1: load config -- every tunable hyper-parameter comes from DefaultConfig.
Conf = DefaultConfig()

BATCH_SIZE = Conf.batch_size
Lr = Conf.lr
EPOCHS = Conf.epochs
DEVICE = Conf.device

# Seed torch's global RNG so that DataLoader shuffling (and, presumably,
# ConvNet's weight initialisation) is reproducible under Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

In [3]:
# Step 2: build the model (ConvNet from Text_Cnn, default constructor arguments).
Text_CNN = ConvNet()

In [4]:
# Step 3: load data and preprocess it.

# Load the raw documents found under Data/movies, then clean them.
data_root = os.path.join('Data', 'movies')
documents = clean_datasets(load_flattened_documents(data_root,None))

# train/val/test splits as returned by load_datasets. The training split is
# named `train_split` (not `train`) so that this cell and the train()
# function defined in the training cell do not overwrite each other when
# cells are re-run out of order.
train_split, val, test = load_datasets(data_root)

# Training dataset and its DataLoader; dynamical_padding is used as the
# collate function (presumably pads variable-length inputs per batch).
Train_Dataset = Movie_Classif_Dataset(documents,train_split)
Loader = DataLoader(dataset = Train_Dataset,
                    batch_size = BATCH_SIZE,
                    shuffle=True,
                    collate_fn=dynamical_padding)

In [5]:
'''
step 4: training and recording 
'''
def train(epochs,model,device,dataloader,Lr,log_every=100,save_path='Text_Cnn.pth',use_wandb=True):
    '''
    Train ``model`` with plain SGD on a negative log-likelihood loss.

    ``nn.NLLLoss`` expects log-probabilities, so ``model(x)`` is assumed to
    end in a log-softmax -- TODO confirm against ConvNet.

    Args:
        epochs: number of full passes over ``dataloader``.
        model: the network to train; moved to ``device`` in place and put
            into training mode.
        device: torch device (or device string) for the model and batches.
        dataloader: iterable of ``(inputs, targets)`` batches.
        Lr: SGD learning rate.
        log_every: print (and log) the loss every this many batches; must be
            a positive integer.
        save_path: file to ``torch.save`` the whole trained model to;
            ``None`` skips saving.
        use_wandb: if True, watch the model and log losses to the active
            wandb run (``wandb.init`` must have been called beforehand).

    Raises:
        ValueError: if ``log_every`` is smaller than 1.
    '''
    if log_every < 1:
        raise ValueError('log_every must be a positive integer')
    loss_func = nn.NLLLoss()
    # Move the model first, then build the optimizer over its parameters
    # (the order recommended by the torch.optim documentation).
    model.to(device)
    optimizer = torch.optim.SGD(model.parameters(),lr=Lr)
    # Enable training-mode behaviour (e.g. dropout) in case the model was
    # left in eval mode by an earlier cell.
    model.train()
    if use_wandb:
        # Register gradient/parameter hooks once, not on every logging step.
        wandb.watch(model)
    for e in range(epochs):
        for i,(inputs,targets) in enumerate(dataloader):
            inputs = inputs.to(device)
            targets = targets.to(device)
            log_probs = model(inputs)
            loss = loss_func(log_probs,targets)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if i % log_every == 0:
                # .item() yields a plain Python float, detached from the graph.
                loss_value = loss.item()
                print('epoch:{}, batch:{}, loss:{}'.format(e,i,loss_value))
                if use_wandb:
                    wandb.log({"loss": loss_value})
    if save_path is not None:
        # Saves the whole module object (pickled), as before.
        torch.save(model,save_path)

# Start a Weights & Biases run for this training session and always close it
# when training ends (or fails): in a notebook the process keeps running, so
# the run is not finished automatically.
wandb.init(project='Text_Cnn',entity='teng_li')
try:
    train(EPOCHS,Text_CNN,DEVICE,Loader,Lr)
finally:
    wandb.finish()

wandb: Currently logged in as: teng_li (use `wandb login --relogin` to force relogin)
wandb: wandb version 0.12.7 is available!  To upgrade, please run:
wandb:  $ pip install wandb --upgrade


epoch:0, batch:0, loss:0.7643529176712036
epoch:0, batch:100, loss:0.6867783069610596
epoch:0, batch:200, loss:0.6926758289337158
epoch:0, batch:300, loss:0.7586843371391296
epoch:0, batch:400, loss:0.6166043281555176
epoch:0, batch:500, loss:0.6673188209533691
epoch:0, batch:600, loss:0.5911650657653809
epoch:0, batch:700, loss:0.7405256628990173
epoch:1, batch:0, loss:0.6578952074050903
epoch:1, batch:100, loss:0.6113396883010864
epoch:1, batch:200, loss:0.7237201929092407
epoch:1, batch:300, loss:0.6108875274658203
epoch:1, batch:400, loss:0.6910234689712524
epoch:1, batch:500, loss:0.6951044797897339
epoch:1, batch:600, loss:0.6917423605918884
epoch:1, batch:700, loss:0.5566729307174683
epoch:2, batch:0, loss:0.8171804547309875
epoch:2, batch:100, loss:0.579620897769928
epoch:2, batch:200, loss:0.6413412094116211
epoch:2, batch:300, loss:0.6574481725692749
epoch:2, batch:400, loss:0.7596308588981628
epoch:2, batch:500, loss:0.6155985593795776
epoch:2, batch:600, loss:0.469539701938

In [6]:
# Step 5: validation.
Val_Dataset = Movie_Classif_Dataset(documents,val)
# One example per batch. Shuffling is unnecessary for computing accuracy, and
# a fixed order keeps validation runs deterministic.
# NOTE(review): this rebinds `Loader`, which the training cell also passes to
# train(); re-running the training cell after this one would train on the
# validation split. Consider a distinct name such as `Val_Loader`.
Loader = DataLoader(dataset = Val_Dataset,
                    batch_size = 1,
                    shuffle=False)

def validation(model,val_dataloader,device=None):
    '''
    Compute, print and return the classification accuracy of ``model``.

    Only the first label of each batch (``y[0]``) is compared with
    ``model.predict(x)``, so the loader is expected to yield one example per
    batch (the validation loader in this notebook uses ``batch_size = 1``).

    The model is switched to eval mode for the pass (so layers such as
    dropout, if present, are disabled), and its previous train/eval mode is
    restored afterwards.

    Args:
        model: ``nn.Module`` exposing a ``predict(x)`` method.
        val_dataloader: iterable of ``(x, y)`` batches.
        device: optional device to move ``x`` to before prediction (e.g. the
            device the model was trained on). ``None`` passes ``x`` unchanged.

    Returns:
        float: accuracy, ``correct / total``.

    Raises:
        ValueError: if ``val_dataloader`` yields no batches.
    '''
    was_training = model.training
    model.eval()
    correct = 0
    total = 0
    try:
        with torch.no_grad():
            for x,y in val_dataloader:
                if device is not None:
                    x = x.to(device)
                y_hat = model.predict(x)
                if y[0] == y_hat:
                    correct += 1
                total += 1
    finally:
        # Restore whatever mode the caller had the model in.
        model.train(was_training)
    if total == 0:
        raise ValueError('validation dataloader yielded no examples')
    acc = correct/total
    print('correct:',correct)
    print('total:',total)
    print('acc = ',acc)
    return acc
    
validation(model=Text_CNN, val_dataloader=Loader);



correct: 166
total: 200
acc =  0.83
