In [None]:
from torch.utils.data import DataLoader
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import pandas as pd

In [None]:
# Load data.
# Each split is a pickled DataFrame with 'input_ids', 'encoder_mask' and 'label'
# columns. Each column is turned into one tensor, then split into a Python list of
# per-example row tensors.
# NOTE: pd.read_pickle unpickles the file -- only load trusted dataset files.
def _columns_as_row_tensors(frame, columns=("input_ids", "encoder_mask", "label")):
    """Return one list of per-row tensors for each requested column of `frame`."""
    return [list(torch.tensor(frame[column].values.tolist())) for column in columns]


SST = pd.read_pickle("./dataset/SST_train")
input_ids, encoder_mask, label = _columns_as_row_tensors(SST)

SST_val = pd.read_pickle("./dataset/SST_validation")
input_ids_val, encoder_mask_val, label_val = _columns_as_row_tensors(SST_val)

In [None]:
# Create dataloaders over lists of (input_ids, encoder_mask, label) row tuples;
# the training loops below unpack each batch as (x, mask, y).
BATCH_SIZE = 48
NUM_WORKERS = 4

sst_dataframe = list(zip(input_ids, encoder_mask, label))
train_iter = DataLoader(sst_dataframe, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)

sst_val_dataframe = list(zip(input_ids_val, encoder_mask_val, label_val))
# Validation order does not change the aggregate metrics. A fixed order makes
# evaluate()/show_sample() output reproducible from run to run.
val_iter = DataLoader(sst_val_dataframe, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)

In [None]:
from torch.nn import ModuleList
import copy

# Vocab: four special tokens come first, then the plain characters that can be
# encoded. Only lowercase letters are present, and '<' / '>' are not in the list.
_SPECIAL_TOKENS = ['<PAD>', '<CLS>', '<SEP>', '<MASK>']
id_to_char = _SPECIAL_TOKENS + [
    ' ', '!', '"', '#', '$', '%', '&', "'", '(', ')', '*', '+', ',', '-', '.', '/',
    '0', '1', '2', '3', '4', '5', '6', '7', '8', '9',
    ':', ';', '=', '?', '@',
    '[', '\\', ']', '^', '_', '`',
    'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm',
    'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
    '{', '|', '}', '~',
]
# Inverse lookup: symbol -> index into id_to_char.
char_to_id = dict(zip(id_to_char, range(len(id_to_char))))
num_spechar = len(_SPECIAL_TOKENS)
# Number of non-special symbols. The model cell below passes len(id_to_char)
# as its vocab size, not this value.
vocab_size = len(id_to_char) - num_spechar
# Number of learned positions in EmbeddingLayer (passed as max_seq_len by CBERT).
data_max_len = 256


class EmbeddingLayer(nn.Module):
    """Character embeddings plus learned absolute position embeddings, then LayerNorm and dropout.

    Args:
        vocab_size: number of rows in the character embedding table.
        embed_size: embedding dimension.
        pad_idx: padding index passed to nn.Embedding (its row stays zero and gets no gradient).
        max_seq_len: number of learned positions; inputs may be at most this long.
        drop_prob: dropout probability applied after LayerNorm.
    """

    def __init__(self, vocab_size, embed_size, pad_idx, max_seq_len, drop_prob):
        super(EmbeddingLayer, self).__init__()
        self.max_seq_len = max_seq_len
        self.char_embedding = nn.Embedding(vocab_size, embed_size, padding_idx=pad_idx)
        self.position_embedding = nn.Embedding(max_seq_len, embed_size)
        self.LayerNorm = nn.LayerNorm(embed_size, eps=1e-7)
        self.dropout = nn.Dropout(drop_prob)

    def forward(self, input_ids):
        """Embed a batch of token ids.

        Args:
            input_ids: integer tensor of shape (batch, seq_len) with seq_len <= max_seq_len.

        Returns:
            Tensor of shape (batch, seq_len, embed_size).

        Raises:
            ValueError: if seq_len exceeds max_seq_len (no position embedding exists for it).
        """
        seq_len = input_ids.size(1)
        if seq_len > self.max_seq_len:
            raise ValueError(
                "sequence length %d exceeds max_seq_len %d" % (seq_len, self.max_seq_len))
        # Position ids follow the actual sequence length, so inputs shorter than
        # max_seq_len work too (a fixed max_seq_len arange cannot expand to them).
        position_ids = torch.arange(seq_len, dtype=torch.long, device=input_ids.device)
        position_ids = position_ids.unsqueeze(0).expand_as(input_ids)

        words_embeddings = self.char_embedding(input_ids)
        position_embeddings = self.position_embedding(position_ids)

        embeddings = words_embeddings + position_embeddings
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings  # (batch, seq_len, embed_size)

class CBERT(nn.Module):
    """Character-level Transformer encoder that pools its outputs into one vector per sequence.

    Attribute names define the state-dict keys that CBERT_SA loads from the
    pretraining checkpoint, so they must not be renamed. ``prediction_layer`` and
    ``log_softmax`` are not used by ``forward``. ``prediction_layer`` still owns
    parameters that appear in the state dict; presumably it is kept from pre-training.
    """

    def __init__(self, vocab_size, embed_size, dim_feedforward, num_heads, num_layers, pad_idx):
        super(CBERT, self).__init__()
        self.embedding_layer = EmbeddingLayer(vocab_size, embed_size, pad_idx, data_max_len, 0.1)
        layer_template = nn.TransformerEncoderLayer(
            d_model=embed_size,
            nhead=num_heads,
            dim_feedforward=dim_feedforward,
            batch_first=True,
        )
        self.transformer_encoder = nn.TransformerEncoder(layer_template, num_layers=num_layers)
        self.prediction_layer = nn.Linear(embed_size, vocab_size)
        self.log_softmax = nn.LogSoftmax(dim=2)

    def forward(self, x, mask):
        """Encode token ids ``x`` and return the mask-weighted mean over the sequence axis.

        Output shape is (batch, embed_size) when x is (batch, seq_len).
        """
        hidden = self.transformer_encoder(self.embedding_layer(x), src_key_padding_mask=mask)
        # NOTE(review): `mask` is used with two different meanings. As
        # src_key_padding_mask, PyTorch ignores positions that are True (bool mask)
        # or adds the values to the attention scores (float mask). The multiply below
        # instead KEEPS positions where mask is nonzero. Confirm the encoder_mask
        # convention. The mean also divides by the full sequence length, not by the
        # number of kept positions.
        weighted = hidden * mask.unsqueeze(-1)
        return weighted.mean(dim=1)
    
class CBERT_SA(nn.Module):
    """Sentiment classifier: pretrained CBERT encoder -> linear head -> softmax.

    Args:
        vocab_size, embed_size, dim_feedforward, num_heads, num_layers, pad_idx:
            forwarded to CBERT; must match the checkpoint's architecture.
        CBERT_PATH: path to a torch checkpoint dict whose 'model_state_dict'
            entry holds the CBERT encoder weights.
        num_class: number of output classes.

    forward() returns class PROBABILITIES (softmax already applied), not logits.
    NOTE(review): nn.CrossEntropyLoss applies log-softmax itself. Pair this output
    with an NLL loss on log-probabilities instead.
    """

    def __init__(self, vocab_size, embed_size, dim_feedforward, num_heads, num_layers, pad_idx, CBERT_PATH, num_class=2):
        super(CBERT_SA, self).__init__()
        self.CBERT = CBERT(vocab_size, embed_size, dim_feedforward, num_heads, num_layers, pad_idx)
        self.prediction_layer = nn.Linear(embed_size, num_class)
        self.softmax = nn.Softmax(dim=1)

        # Restore the pretrained encoder. map_location='cpu' lets a checkpoint
        # saved from a GPU load on any host. Move the module to the target device
        # afterwards, e.g. model.to('cuda').
        # torch.load unpickles the file: only load trusted checkpoints.
        checkpoint = torch.load(CBERT_PATH, map_location='cpu')
        self.CBERT.load_state_dict(checkpoint['model_state_dict'])

        # Fresh classification head: Xavier-uniform weights, zero bias.
        nn.init.xavier_uniform_(self.prediction_layer.weight)
        nn.init.zeros_(self.prediction_layer.bias)

    def forward(self, x, mask):
        """Return class probabilities of shape (batch, num_class) for token ids ``x`` and ``mask``."""
        pooled = self.CBERT(x, mask)
        logits = self.prediction_layer(pooled)
        return self.softmax(logits)

In [None]:
# Build the classifier. CBERT_SA restores the pretrained encoder weights from
# CBERT_PATH inside its constructor; Module.to() returns the module itself.
model = CBERT_SA(
    vocab_size=len(id_to_char),
    embed_size=768,
    dim_feedforward=2048,
    num_heads=12,
    num_layers=6,
    pad_idx=char_to_id['<PAD>'],
    CBERT_PATH="./models/yelp6_cp12.pt",
).to(device='cuda')

In [None]:
from tqdm import tqdm

def train(model, optimizer, loss_f, train_iter, num_epochs, device='cpu', prnt_intv=1, max_grad_norm=1.0):
    """Train ``model`` for ``num_epochs`` epochs and print per-sample loss and accuracy.

    Args:
        model: module called as model(x, mask), returning (batch, num_class) scores.
        optimizer: optimizer over model's parameters.
        loss_f: loss called as loss_f(y_hat, y). Assumed to average over the batch
            (reduction='mean', the CrossEntropyLoss/NLL default).
        train_iter: iterable of (x, mask, y) batches.
        num_epochs: number of passes over train_iter.
        device: device the model and batches are moved to.
        prnt_intv: print metrics every this many epochs.
        max_grad_norm: global gradient-norm clipping threshold. The previous
            hard-coded 5e-5 (equal to the learning rate) rescaled every gradient
            below Adam's epsilon and stalled updates.
    """
    model = model.to(device=device)
    for epoch in range(num_epochs):
        model.train()  # nothing inside the loop changes the mode, so set it once per epoch
        # Accumulate on-device to avoid a host sync on every batch.
        train_loss_sum = torch.zeros((), device=device)
        train_acc_sum = torch.zeros((), device=device)
        n = 0
        for x, mask, y in train_iter:
            x, mask, y = x.to(device=device), mask.to(device=device), y.to(device=device)
            optimizer.zero_grad()
            y_hat = model(x, mask)

            loss = loss_f(y_hat, y)
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
            optimizer.step()

            with torch.no_grad():
                batch_size = x.shape[0]
                # loss is a batch mean: weight it by batch size so dividing by n
                # yields a per-sample mean (the last batch may be smaller).
                train_loss_sum += loss.detach().float() * batch_size
                pred = torch.argmax(y_hat, dim=1)
                train_acc_sum += torch.sum(pred == y)
                n += batch_size

        if (epoch + 1) % prnt_intv == 0:
            denom = max(n, 1)  # an empty train_iter reports 0 instead of dividing by zero
            print("Epoch:%d Loss:%f, TrainAcc:%f" % (
                epoch + 1, (train_loss_sum / denom).item(), (train_acc_sum / denom).item()))
            
def evaluate(model, loss_f, val_iter, device='cpu'):
    """Print the per-sample mean loss and the accuracy of ``model`` over ``val_iter``.

    Args:
        model: module called as model(x, mask), returning (batch, num_class) scores.
        loss_f: loss called as loss_f(y_hat, y). Assumed to average over the batch
            (reduction='mean', the CrossEntropyLoss/NLL default).
        val_iter: iterable of (x, mask, y) batches.
        device: device the model and batches are moved to.
    """
    model = model.to(device=device)
    model.eval()
    with torch.no_grad():
        loss_sum = torch.zeros((), device=device)
        acc_sum = torch.zeros((), device=device)
        n = 0

        for x, mask, y in val_iter:
            x, mask, y = x.to(device=device), mask.to(device=device), y.to(device=device)
            y_hat = model(x, mask)
            batch_size = x.shape[0]
            # Weight the batch-mean loss by batch size so that dividing by n gives
            # a per-sample mean rather than (mean batch loss) / batch_size.
            loss_sum += loss_f(y_hat, y).float() * batch_size
            pred = torch.argmax(y_hat, dim=1)
            acc_sum += torch.sum(pred == y)
            n += batch_size

    denom = max(n, 1)  # an empty val_iter reports 0 instead of dividing by zero
    print("Val_Loss:%f, Val_Acc:%f" % ((loss_sum / denom).item(), (acc_sum / denom).item()))

def print_decoded_ids(ids):
    """Print the characters encoded by ``ids`` on a single line.

    Decoding stops at the first id equal to 0 (<PAD>). Every earlier id is mapped
    through ``id_to_char``, so other special tokens (e.g. <CLS>, <SEP>) are printed
    by name rather than skipped.
    """
    decoded_chars = []
    for token_id in ids:
        if token_id == 0:
            break
        decoded_chars.append(id_to_char[token_id])
    print(''.join(decoded_chars))
    
def show_sample(model, data_iter, loss_f, device='cpu', max_batches=11):
    """Print decoded inputs, targets and predictions for the first few batches.

    The model stays in eval mode (dropout disabled) and runs without gradients,
    so the printed predictions are deterministic and match evaluate().

    Args:
        model: module called as model(x, mask), returning (batch, num_class) scores.
        data_iter: iterable of (x, mask, y) batches.
        loss_f: loss called as loss_f(y_hat, y); its value is printed per batch.
        device: device the model and batches are moved to.
        max_batches: number of batches to print (default 11).
    """
    model = model.to(device=device)
    model.eval()
    with torch.no_grad():
        for batch_idx, (x, mask, y) in enumerate(data_iter):
            if batch_idx >= max_batches:
                break
            x, mask, y = x.to(device=device), mask.to(device=device), y.to(device=device)
            y_hat = model(x, mask)
            loss = loss_f(y_hat, y)
            pred = torch.argmax(y_hat, dim=1)

            for ids, p, t in zip(x, pred, y):
                print("INPUT:", end=' ')
                print_decoded_ids(ids)
                print("TARGET:", t.item(), end=' ')
                print("PREDICTION:", p.item())
            print("loss:", loss)
            print("acc:", torch.sum(pred == y) / x.shape[0])
            print("=" * 80)

In [None]:
# Fine-tune the classifier, printing validation metrics after every epoch.
NUM_EPOCHS = 5
optimizer = optim.Adam(model.parameters(), lr=5e-5, weight_decay=0.0)


def loss_f(probs, target):
    """Cross-entropy for a model whose output is already softmax probabilities.

    CBERT_SA.forward applies softmax, while nn.CrossEntropyLoss expects raw logits
    and applies log-softmax itself. Feeding it probabilities would apply softmax
    twice: the achievable loss is bounded away from zero and the gradients shrink.
    Taking the log of the probabilities and using NLL gives the intended
    cross-entropy.
    """
    # clamp avoids log(0) = -inf if a probability underflows to zero
    return nn.functional.nll_loss(torch.log(probs.clamp_min(1e-12)), target)


for _ in range(NUM_EPOCHS):
    train(model, optimizer, loss_f, train_iter, num_epochs=1, device='cuda', prnt_intv=1)
    evaluate(model, loss_f, val_iter, device='cuda')

In [None]:
# Validation metrics for the fine-tuned model.
evaluate(model, loss_f, val_iter, device="cuda")

In [None]:
# Print decoded validation examples next to the model's predictions.
show_sample(model, val_iter, loss_f, device="cuda")

In [None]:
# Save only the fine-tuned weights, under the same 'model_state_dict' key used by
# the pretraining checkpoint that CBERT_SA reads.
PATH = "./models/SST_sample.pt"
checkpoint_to_save = {'model_state_dict': model.state_dict()}
torch.save(checkpoint_to_save, PATH)