In [44]:
import torch
import torch.nn.functional as F
import pandas as pd

from torch.utils.data import DataLoader, Dataset, TensorDataset, Sampler
from keras.preprocessing.sequence import pad_sequences
from pytorch_pretrained_bert import BertAdam
from pytorch_pretrained_bert import BertForSequenceClassification

In [2]:
data = pd.read_csv('./data/bert_data.csv')

In [3]:
data

Unnamed: 0,review,sentiment,ko_bt,org_ids,ko_ids
0,One of the other reviewers has mentioned that ...,positive,One of the other reviewers mentioned that you ...,"[101, 2028, 1997, 1996, 2060, 15814, 2038, 385...","[101, 2028, 1997, 1996, 2060, 15814, 3855, 200..."
1,A wonderful little production. <br /><br />The...,positive,Nice little production. <br /><br /> The shoot...,"[101, 1037, 6919, 2210, 2537, 1012, 1026, 7987...","[101, 3835, 2210, 2537, 1012, 1026, 7987, 1013..."
2,I thought this was a wonderful way to spend ti...,positive,I thought this was a great way to spend time o...,"[101, 1045, 2245, 2023, 2001, 1037, 6919, 2126...","[101, 1045, 2245, 2023, 2001, 1037, 2307, 2126..."
3,Basically there's a family where a little boy ...,negative,Basically there is a family where a little boy...,"[101, 10468, 2045, 1005, 1055, 1037, 2155, 207...","[101, 10468, 2045, 2003, 1037, 2155, 2073, 103..."
4,"Petter Mattei's ""Love in the Time of Money"" is...",positive,Petter Mattei&#39;s &quot;Love for Money&#39;s...,"[101, 9004, 3334, 4717, 7416, 1005, 1055, 1000...","[101, 9004, 3334, 4717, 7416, 1004, 1001, 4464..."
5,"Probably my all-time favorite movie, a story o...",positive,"Perhaps the story of the best film ever, selfl...","[101, 2763, 2026, 2035, 1011, 2051, 5440, 3185...","[101, 3383, 1996, 2466, 1997, 1996, 2190, 2143..."
6,I sure would like to see a resurrection of a u...,positive,I would like to see the revival of the Seahunt...,"[101, 1045, 2469, 2052, 2066, 2000, 2156, 1037...","[101, 1045, 2052, 2066, 2000, 2156, 1996, 6308..."
7,"This show was an amazing, fresh & innovative i...",negative,This show was a stunning fresh and innovative ...,"[101, 2023, 2265, 2001, 2019, 6429, 1010, 4840...","[101, 2023, 2265, 2001, 1037, 14726, 4840, 199..."
8,Encouraged by the positive comments about this...,negative,I was looking forward to seeing this movie bec...,"[101, 6628, 2011, 1996, 3893, 7928, 2055, 2023...","[101, 1045, 2001, 2559, 2830, 2000, 3773, 2023..."
9,If you like original gut wrenching laughter yo...,positive,"If you like good old-fashioned laughs, you&#39...","[101, 2065, 2017, 2066, 2434, 9535, 16255, 845...","[101, 2065, 2017, 2066, 2204, 2214, 1011, 1340..."


In [53]:
def convert_ids_to_tensor(lst):
    org_id_lst = [eval(i)[:512] for i in lst]
    org_ids = pad_sequences(org_id_lst,maxlen=512,padding='post')
    org_tensor = torch.from_numpy(org_ids)
    return org_tensor

In [152]:
y = torch.tensor(data.sentiment.replace({'negative':0,'positive':1}).values).type(torch.float32).view(-1,1)

In [110]:
org_tensor = convert_ids_to_tensor(data.org_ids.values)
trans_tensor = convert_ids_to_tensor(data.ko_ids.values)

In [163]:
training_td = TensorDataset(org_tensor[:20].type(torch.int64),y[:20])
aux_td = TensorDataset(org_tensor.type(torch.int64),trans_tensor.type(torch.int64),y)

In [280]:
test_td = TensorDataset(org_tensor[-10000:].type(torch.int64),y[-10000:])

In [283]:
len(test_td)

10000

In [47]:
class FixlengthSampler(Sampler):
    def __init__(self, data_source, length=None):
        self.data_source = data_source
        self.length = length if length is not None else len(self.data_source)
    def __iter__(self,):
        return iter(torch.randint(low=0,high=len(self.data_source),size=(self.length,)))
    def __len__(self): 
        return self.length

In [304]:
training_steps = 10
aux_batch_size = 8
training_batch_size = 16
test_batch_size = 16

training_sampler = FixlengthSampler(training_td,length=training_steps*training_batch_size)
training_dl = DataLoader(training_td,batch_size=training_batch_size,sampler=training_sampler)

aux_sampler = FixlengthSampler(aux_td,length=training_steps*aux_batch_size)
aux_dl = DataLoader(aux_td,batch_size=aux_batch_size,sampler=aux_sampler)

test_dl = DataLoader(test_td,batch_size=16)

In [305]:
### quick test
aux_iter = iter(aux_dl)
bax1,bax2,bay = next(aux_iter)

In [172]:
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
model = BertForSequenceClassification.from_pretrained('bert-base-uncased',num_labels=1)
model = model.to(device)

In [300]:
def kl_div(bl1,bl2):
    bl1_prob,bl2_prob = torch.sigmoid(bl1), torch.sigmoid(bl2)
    return (bl1_prob*(bl1_prob/bl2_prob).log()).mean()

In [314]:
# this is the same example in wiki
P = torch.Tensor([0.36, 0.48, 0.16])
Q = torch.Tensor([0.333, 0.333, 0.333])
(P * (P / Q).log()).mean()
# tensor(0.0863), 10.2 µs ± 508
F.kl_div(Q.log(), P, None, None, 'batchmean')
# tensor(0.0863), 14.1 µs ± 408 ns

tensor(0.0288)

In [320]:
def inverse_sig(x):
    return torch.log(x/(1-x))

In [321]:
kl_div(inverse_sig(P),inverse_sig(Q))

tensor(0.0288)

In [322]:
kl_div(torch.tensor(1.9),torch.tensor(5.8))

tensor(-0.1186)

In [323]:
F.kl_div(torch.tensor(5.8).sigmoid().log(), torch.tensor(1.9).sigmoid(), reduction='batchmean')

tensor(-0.1186)

In [276]:
opt = BertAdam(model.parameters(),
               lr=2e-5,
               warmup=0.02, 
               t_total=training_steps,
              )

In [None]:
def eval_model(model):
    acc = 0.0
    for bx,by in test_dl:
        attention_mask = bx > 0 
        bx,by = bx.to(device),by.to(device)
        bl = model(bx,attention_mask=attention_mask)
        bacc = ((bl >= 0).astype(torch.float32) == by).sum()
        acc += bacc
    acc = acc/len(text_dl)
    return acc

In [294]:
uda = True
lamda = 0.2
training_iter = iter(training_dl)
aux_iter = iter(aux_dl)
# eval_per_steps = int(training_steps/10)
eval_per_steps = 100

model = model.train()
for i in range(training_steps):
    bx,by = next(training_iter)
    bx,by = bx.to(device),by.to(device)
    
    if uda: 
        bax1,bax2,bay = next(aux_iter)
        bax1,bax2,bay = bax1.to(device),bax2.to(device),bay.to(device)
        
    progress = torch.tensor(i/training_steps)
    ita = cal_ita(progress)
    
    attention_mask = bx > 0
    bl = model(bx,attention_mask=attention_mask)
    if uda:
        bal1,bal2 = model(bax1),model(bax2)
        
        loss = tsa_loss(bl,by,ita)
        aux_loss = kl_div(bal1,bal2)
        loss = loss + lamda * aux_loss
    else:
#         loss = tsa_loss(bl,by)
        loss = torch.binary_cross_entropy_with_logits(bl,by)
    
    opt.zero_grad()
    loss.backward()
    opt.step()
    
#     loss =  
    if i % eval_per_steps == 0:
        model = model.eval()
        acc = eval_model(model)
        model = model.train()
        
    break
    
    print(ita)

TypeError: tsa_loss() missing 1 required positional argument: 'ita'

In [298]:
loss = tsa_loss(bl,by,ita)
loss

tensor(0.7853, grad_fn=<MulBackward0>)

In [301]:
aux_loss = kl_div(bl,)
aux_loss

tensor(-0.0634, grad_fn=<MeanBackward1>)

In [None]:
torch.kl_div(by)

In [302]:
bl

tensor([[0.0854],
        [0.2212],
        [0.2692],
        [0.2029],
        [0.1858],
        [0.2984],
        [0.1517],
        [0.2177],
        [0.1325],
        [0.1521],
        [0.2172],
        [0.3126],
        [0.3012],
        [0.0492],
        [0.3645],
        [0.2692]], grad_fn=<AddmmBackward>)

In [303]:
by

tensor([[1.],
        [1.],
        [0.],
        [0.],
        [1.],
        [1.],
        [0.],
        [1.],
        [0.],
        [0.],
        [1.],
        [1.],
        [1.],
        [0.],
        [1.],
        [0.]])

In [269]:
tsa_loss

<function __main__.tsa_loss(logits, labels, ita)>

In [202]:
(P * (P / Q).log()).sum()

tensor(0.0863)

In [97]:
logits = torch.tensor([[10.,-3.]])
logits

tensor([[10., -3.]])

In [148]:
labels = torch.tensor([[1,0]])
labels

tensor([[1, 0]])

In [149]:
labels.dtype

torch.int64

In [103]:
tsa_loss(logits,labels,ita)

tensor(0.)

In [104]:
probs = torch.sigmoid(logits)
probs 

tensor([[1.0000, 0.0474]])

In [105]:
confi_probs = 1 - torch.abs(probs - labels)
confi_probs

tensor([[1.0000, 0.9526]])

In [106]:
confi_probs

tensor([[1.0000, 0.9526]])

In [107]:
masks = (confi_probs <= ita).type(torch.float32)
masks

tensor([[0., 0.]])

In [108]:
F.binary_cross_entropy(probs,labels,reduction='none') * masks

tensor([[0., 0.]])

In [109]:
F.binary_cross_entropy_with_logits(logits,labels,reduction='none')

tensor([[4.5418e-05, 4.8587e-02]])

In [296]:
def tsa_loss(logits,labels,ita):
    probs = torch.sigmoid(logits)
    confi_probs = 1 - torch.abs(probs - labels)
    masks = (confi_probs <= ita).type(torch.float32)
#     print(masks)
    loss = torch.mean(F.binary_cross_entropy(probs,labels,reduction='none') * masks)
    loss = 1/masks.mean()*loss
    return loss
def cal_ita(progress):
    return torch.exp((progress-1)*5)*(1-1/2)+1/2

In [265]:
tsa_loss(logits,labels.type(torch.float32),ita)

tensor([[0., 1.]])


tensor(3.0486)

In [266]:
labels

tensor([[1, 0]])

In [267]:
logits

tensor([[10.,  3.]])

In [None]:
F.