In [2]:
import os
import pdb
import argparse
from dataclasses import dataclass, field
from typing import Optional
from collections import defaultdict

import torch
from torch.nn.utils.rnn import pad_sequence

import numpy as np
from tqdm import tqdm, trange

from transformers import (
    BertForSequenceClassification,
    BertTokenizer,
    AutoConfig,
    AdamW
)

from preprocess import data2dataset
from utils import collate_fn_style, collate_fn_style_test, compute_acc

In [None]:
!pip install transformers wandb


In [None]:
import wandb

# Authenticate with Weights & Biases before starting a run.
wandb.login()

# Default hyperparameters handed to `wandb.init` as the run configuration.
hyperparameter_defaults = {
    "dropout": 0.1,
    "batch_size": 64,
    "learning_rate": 1e-5,
    "epochs": 3,
    "seed": 42,
    "architecture": "BERT",
    "classes": 2,
    "tr_loss_check": 50,  # log the training loss every N iterations
}

wandb.init(
    config=hyperparameter_defaults,
    project="baseline-LR",
    entity="goorm_nlp_project_1",
    reinit=True,
)

# The rest of the notebook reads hyperparameters through `config`.
config = wandb.config
wandb.run.name = 'baseline-LR[]'

# Preprocess

In [3]:
# Build the three datasets and `test_df` with the project's preprocessing helper.
train_dataset, dev_dataset, test_dataset, test_df = data2dataset()

# Training batches are shuffled; pinned memory is requested for this loader only.
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=config.batch_size,
    shuffle=True,
    collate_fn=collate_fn_style,
    pin_memory=True,
    num_workers=2,
)

# Validation batches keep their original order.
dev_loader = torch.utils.data.DataLoader(
    dev_dataset,
    batch_size=config.batch_size,
    shuffle=False,
    collate_fn=collate_fn_style,
    num_workers=2,
)

Downloading: 100%|██████████| 226k/226k [00:00<00:00, 632kB/s] 
Downloading: 100%|██████████| 28.0/28.0 [00:00<00:00, 9.17kB/s]
Downloading: 100%|██████████| 570/570 [00:00<00:00, 172kB/s]


FileNotFoundError: [Errno 2] No such file or directory: 'sentiment.train.1'

In [None]:
# Seed NumPy and PyTorch from the run configuration.
np.random.seed(config.seed)
torch.manual_seed(config.seed)

# Prefer the GPU when one is available.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Feed the run-config hyperparameters into the model configuration so that
# changing `classes` / `dropout` in the wandb config actually takes effect.
# With the default values (2 labels, dropout 0.1) this matches the stock
# 'bert-base-uncased' configuration.
model = BertForSequenceClassification.from_pretrained(
    'bert-base-uncased',
    num_labels=config.classes,
    hidden_dropout_prob=config.dropout,
)
model.to(device)

# Let wandb hook into the model for logging.
wandb.watch(model)

In [None]:
# Put the model in training mode (enables dropout) before optimisation starts.
model.train()

# Use PyTorch's own AdamW: the `AdamW` class shipped with transformers is
# deprecated in favour of it. `eps` and `weight_decay` are set explicitly to
# the values the transformers implementation used by default, so the
# optimiser hyperparameters stay the same as before.
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=config.learning_rate,
    eps=1e-6,
    weight_decay=0.0,
)

# Train

In [None]:
# Fine-tune the classifier for `config.epochs` epochs, validating five times
# per epoch and checkpointing whenever the validation loss improves.
lowest_valid_loss = float("inf")

# Number of training batches between validation passes. max(1, ...) keeps the
# modulo below well-defined when the loader has fewer than five batches.
eval_interval = max(1, len(train_loader) // 5)

for epoch in range(config.epochs):
    # Make sure dropout is active even if a previous cell left the model in
    # evaluation mode.
    model.train()
    with tqdm(train_loader, unit="batch") as tepoch:
        for iteration, (input_ids, attention_mask, token_type_ids, position_ids, labels) in enumerate(tepoch):
            tepoch.set_description(f"Epoch {epoch}")
            input_ids = input_ids.to(device)
            attention_mask = attention_mask.to(device)
            token_type_ids = token_type_ids.to(device)
            position_ids = position_ids.to(device)
            labels = labels.to(device, dtype=torch.long)

            optimizer.zero_grad()

            output = model(input_ids=input_ids,
                           attention_mask=attention_mask,
                           token_type_ids=token_type_ids,
                           position_ids=position_ids,
                           labels=labels)

            loss = output.loss
            loss.backward()

            optimizer.step()
            tepoch.set_postfix(loss=loss.item())

            if iteration % config.tr_loss_check == 0:
                # Log a plain float rather than a graph-attached tensor.
                wandb.log({"train_loss": loss.item()})

            if iteration != 0 and iteration % eval_interval == 0:
                # Validation pass: evaluation mode, no gradient tracking.
                model.eval()
                valid_losses = []
                predictions = []
                target_labels = []
                with torch.no_grad():
                    # Separate names for the dev batch so the training batch
                    # variables are not overwritten.
                    for (dev_input_ids, dev_attention_mask, dev_token_type_ids,
                         dev_position_ids, dev_labels) in tqdm(dev_loader,
                                                               desc='Eval',
                                                               position=1,
                                                               leave=None):
                        dev_input_ids = dev_input_ids.to(device)
                        dev_attention_mask = dev_attention_mask.to(device)
                        dev_token_type_ids = dev_token_type_ids.to(device)
                        dev_position_ids = dev_position_ids.to(device)
                        dev_labels = dev_labels.to(device, dtype=torch.long)

                        dev_output = model(input_ids=dev_input_ids,
                                           attention_mask=dev_attention_mask,
                                           token_type_ids=dev_token_type_ids,
                                           position_ids=dev_position_ids,
                                           labels=dev_labels)

                        valid_losses.append(dev_output.loss.item())

                        # Predicted class = index of the largest logit (ties
                        # resolve to the lowest class index); works for any
                        # number of classes.
                        predictions += dev_output.logits.argmax(dim=-1).tolist()
                        target_labels += [int(example) for example in dev_labels]

                # Switch back to training mode; otherwise every step after the
                # first validation pass would run with dropout disabled.
                model.train()

                acc = compute_acc(predictions, target_labels)
                valid_loss = sum(valid_losses) / len(valid_losses)
                wandb.log({"val_acc": acc, "val_loss": valid_loss})

                if valid_loss < lowest_valid_loss:
                    # Remember the new best loss so that only genuinely better
                    # models overwrite the checkpoint.
                    lowest_valid_loss = valid_loss
                    print('Acc for model which have lower valid loss: ', acc)
                    torch.save(model.state_dict(), "./pytorch_model.bin")
wandb.finish()

In [None]:
# Unshuffled loader over the test set; it uses the test-specific collate function.
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=config.batch_size,
    shuffle=False,
    collate_fn=collate_fn_style_test,
    num_workers=2,
)

In [None]:
# Predict a class for every test example, in loader order.
# NOTE(review): this scores with the weights currently held in `model`; no
# checkpoint is loaded in this cell — confirm that is the intended model.
with torch.no_grad():
    model.eval()
    predictions = []
    for input_ids, attention_mask, token_type_ids, position_ids in tqdm(test_loader,
                                                                        desc='Test',
                                                                        position=1,
                                                                        leave=None):

        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)
        token_type_ids = token_type_ids.to(device)
        position_ids = position_ids.to(device)

        output = model(input_ids=input_ids,
                       attention_mask=attention_mask,
                       token_type_ids=token_type_ids,
                       position_ids=position_ids)

        # Predicted class = index of the largest logit (ties resolve to the
        # lowest class index); not limited to two classes.
        predictions += output.logits.argmax(dim=-1).tolist()

In [None]:
# Store the predictions in the 'Category' column and write the submission
# file without the index column.
test_df['Category'] = predictions
test_df.to_csv(path_or_buf='submission.csv', index=False)