In [9]:
from google.colab import drive
drive.mount('/content/drive')

Go to this URL in a browser: https://accounts.google.com/o/oauth2/auth?client_id=947318989803-6bn6qk8qdgf4n4g3pfee6491hc0brc4i.apps.googleusercontent.com&redirect_uri=urn%3aietf%3awg%3aoauth%3a2.0%3aoob&response_type=code&scope=email%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdocs.test%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive.photos.readonly%20https%3a%2f%2fwww.googleapis.com%2fauth%2fpeopleapi.readonly

Enter your authorization code:
··········
Mounted at /content/drive


In [10]:
import pandas as pd
import torch
!pip install transformers
from transformers import BertTokenizer
# from keras.preprocessing.sequence import pad_sequences
from torch.nn import functional as F

Collecting transformers
[?25l  Downloading https://files.pythonhosted.org/packages/13/33/ffb67897a6985a7b7d8e5e7878c3628678f553634bd3836404fef06ef19b/transformers-2.5.1-py3-none-any.whl (499kB)
[K     |████████████████████████████████| 501kB 12.4MB/s 
[?25hCollecting sentencepiece
[?25l  Downloading https://files.pythonhosted.org/packages/74/f4/2d5214cbf13d06e7cb2c20d84115ca25b53ea76fa1f0ade0e3c9749de214/sentencepiece-0.1.85-cp36-cp36m-manylinux1_x86_64.whl (1.0MB)
[K     |████████████████████████████████| 1.0MB 44.7MB/s 
Collecting sacremoses
[?25l  Downloading https://files.pythonhosted.org/packages/a6/b4/7a41d630547a4afd58143597d5a49e07bfd4c42914d8335b2a5657efc14b/sacremoses-0.0.38.tar.gz (860kB)
[K     |████████████████████████████████| 870kB 38.7MB/s 
Collecting tokenizers==0.5.2
[?25l  Downloading https://files.pythonhosted.org/packages/d1/3f/73c881ea4723e43c1e9acf317cf407fab3a278daab3a69c98dcac511c04f/tokenizers-0.5.2-cp36-cp36m-manylinux1_x86_64.whl (3.7MB)
[K     |███

In [0]:
import pickle
with open('/content/drive/My Drive/data.pkl','rb') as f:
  df=pickle.load(f)

In [0]:
MAX_LENGTH=len(df.iloc[0]['encoding'])
tag_index={'[PAD]':0,'<s>':1,'B-AG':2,'I-AG':3,'B-TG':4,'I-TG':5,'O':6,'</s>':7}
df['labels']=df['labels'].apply(lambda x: [tag_index[lbl] for lbl in x])

In [0]:
from torch.utils.data import Dataset
class SequenceDataset(Dataset):
  def __init__(self,df):
    self.df=df

  def __len__(self):
    return len(self.df)

  def __getitem__(self,index):
    return torch.tensor(self.df.iloc[index]['encoding']),torch.tensor(self.df.iloc[index]['attn_mask']),torch.tensor(self.df.iloc[index]['labels'])

In [0]:
from torch.utils.data import DataLoader
import numpy as np
msk = np.random.rand(len(df)) < 0.8
train=df[msk]
val=df[~msk]
train_set=SequenceDataset(train)
val_set=SequenceDataset(val)
train_loader=DataLoader(train_set, batch_size = 16)
val_loader = DataLoader(val_set, batch_size = 16)

In [0]:
import torch.nn as nn
from transformers import BertModel
class Encoder(nn.Module):
    def __init__(self, freeze_bert = True):
        super(Encoder, self).__init__()
        self.bert_layer = BertModel.from_pretrained('bert-base-uncased')

        for p in self.bert_layer.parameters():
            p.requires_grad = False

    def forward(self, seq, attn_masks):
        cont_reps, _ = self.bert_layer(seq, attention_mask = attn_masks)
        return cont_reps[:,0]        

In [0]:
class Decoder(nn.Module):
    def __init__(self,vocab_size, hidden_size,output_size, dropout_p=0.1):
        super(Decoder, self).__init__()
        self.hidden_size = hidden_size
        self.vocab_size=vocab_size
        self.output_size = output_size
        self.dropout = nn.Dropout(0.1)
        self.embedding = nn.Embedding(self.vocab_size,self.hidden_size)
        self.gru = nn.GRU(self.hidden_size, self.hidden_size,batch_first=True)
        self.out = nn.Linear(self.hidden_size, self.output_size)

    def forward(self, hidden,input):
       embedded = self.embedding(input).view(hidden.shape[0], 1, -1)
       embedded = self.dropout(embedded)
       output, hidden = self.gru(embedded, hidden.permute(1,0,2).contiguous())
       output = F.log_softmax(self.out(output), dim=2)
       return output, hidden.permute(1,0,2)   

In [0]:
class AttnDecoderRNN(nn.Module):
    def __init__(self,vocab_size, hidden_size,output_size, dropout_p=0.1, max_length=MAX_LENGTH):
        super(AttnDecoderRNN, self).__init__()
        self.hidden_size = hidden_size
        self.vocab_size=vocab_size
        # self.embedding_size=embedding_size
        self.output_size = output_size
        self.dropout_p = dropout_p
        self.max_length = max_length

        self.embedding = nn.Embedding(self.vocab_size,self.hidden_size)
        self.attn = nn.Linear(self.hidden_size * 2, self.max_length)
        self.attn_combine = nn.Linear(self.hidden_size * 2, self.hidden_size)
        self.dropout = nn.Dropout(0.1)
        self.gru = nn.GRU(self.hidden_size, self.hidden_size,batch_first=True)
        self.out = nn.Linear(self.hidden_size, self.output_size)

    def forward(self, input, hidden, encoder_outputs):
        # print(input.shape,hidden.shape,encoder_outputs.shape)
        embedded = self.embedding(input).view(encoder_outputs.shape[0], 1, -1)
        embedded = self.dropout(embedded)
        
        hidden=hidden.expand(encoder_outputs.shape[0],1,-1)

        attn_weights = F.softmax(
            self.attn(torch.cat((embedded, hidden), 2)), dim=2)
        # print(attn_weights.shape,encoder_outputs.shape)
        # attn_applied=attn_weights.permute(0,2,1)*encoder_outputs
        attn_applied = torch.bmm(attn_weights,
                                 encoder_outputs)

        output = torch.cat((embedded, attn_applied),dim=2)
        output = self.attn_combine(output)

        output = F.relu(output)
        output, hidden = self.gru(output, hidden.permute(1,0,2).contiguous())
        output = F.log_softmax(self.out(output), dim=2)
        return output, hidden.permute(1,0,2), attn_weights

    def initHidden(self):
        return torch.zeros(1, 1, self.hidden_size, device=device)

In [0]:
import torch.optim as optim
import tqdm
import os
import pickle
# embedding_size=100
hidden_size=768
# bert_model=BertModel.from_pretrained('bert-base-uncased')
vocab_size=len(tag_index)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
encoder=Encoder().to(device)

if(os.path.exists('/content/drive/My Drive/BERT-SEQ-Tagger/encoder.pt')):
    encoder.load_state_dict(torch.load('/content/drive/My Drive/BERT-SEQ-Tagger/encoder.pt'))

decoder=Decoder(vocab_size,hidden_size,len(tag_index)).to(device)
if(os.path.exists('/content/drive/My Drive/BERT-SEQ-Tagger/decoder.pt')):
    decoder.load_state_dict(torch.load('/content/drive/My Drive/BERT-SEQ-Tagger/decoder.pt'))

criterion = nn.NLLLoss(ignore_index=tag_index['[PAD]'])
# enc_optimizer = optim.Adam(encoder.parameters(), lr = 2e-5)
dec_optimizer = optim.Adam(decoder.parameters(), lr = 1e-5)

# training_loss=[]
val_losses=[]
if(os.path.exists('/content/drive/My Drive/BERT-SEQ-Tagger/val_losses.pkl')):
  with open('/content/drive/My Drive/BERT-SEQ-Tagger/val_losses.pkl','rb') as f:
    val_losses=pickle.load(f)

for _e in range(200):
    train_loss=0
    for t, (seq, attn_mask, labels) in enumerate(train_loader):
        # data_batch = sort_batch_by_len(data_dict)
        batch_size=seq.shape[0]
        seq=seq.to(device)
        attn_mask=attn_mask.to(device)
        labels =labels.to(device) #torch.tensor(data_batch).to(device)
                
        # enc_optimizer.zero_grad()
        dec_optimizer.zero_grad()
        encoder_output=encoder(seq,attn_mask)        
        decoder_input = torch.tensor([batch_size*[tag_index['<s>']]], device=device).view(-1,1)
        decoder_hidden=encoder_output.view(batch_size,1,-1)
        labels= torch.cat((labels,torch.tensor(batch_size*[tag_index['</s>']], device=device).view(-1,1)),dim=1)
        loss=0
        for di in range(labels.shape[1]):
          decoder_output,decoder_hidden=decoder(decoder_hidden,decoder_input)
          # print(decoder_output.squeeze(0).shape)
          loss += criterion(decoder_output.view(encoder_output.shape[0],-1), labels[:,di])
          train_loss+=loss.data.item()
          decoder_input = labels[:,di]
        loss.backward()   
        # enc_optimizer.step()
        dec_optimizer.step()
    train_loss=train_loss/len(train)
    # print(train_loss)    
    val_loss=0
    for t, (seq, attn_mask, labels) in enumerate(val_loader):
        seq=seq.to(device)
        attn_mask=attn_mask.to(device)
        labels =labels.to(device) #torch.tensor(data_batch).to(device)
        batch_size=seq.shape[0]
        # enc_optimizer.zero_grad()
        dec_optimizer.zero_grad()
        encoder_output=encoder(seq,attn_mask)
        decoder_input = torch.tensor([batch_size*[tag_index['<s>']]], device=device).view(-1,1)
        decoder_hidden=encoder_output.view(batch_size,1,-1)
        labels= torch.cat((labels,torch.tensor(batch_size*[tag_index['</s>']], device=device).view(-1,1)),dim=1)
        loss=0
        for di in range(labels.shape[1]):
          decoder_output,decoder_hidden=decoder(decoder_hidden,decoder_input)
          # print(decoder_output.squeeze(0).shape)
          loss += criterion(decoder_output.view(encoder_output.shape[0],-1), labels[:,di])
          decoder_input = labels[:,di]
          # _, top_idx = decoder_output.data.topk(1)
          # decoder_input = top_idx.view(-1)
          val_loss+=loss.data.item()
    val_loss=val_loss/len(val)
    if(len(val_losses)>0 and val_loss<min(val_losses)):
      torch.save(encoder.state_dict(), '/content/drive/My Drive/BERT-SEQ-Tagger/encoder.pt') 
      torch.save(decoder.state_dict(), '/content/drive/My Drive/BERT-SEQ-Tagger/decoder.pt')  
    val_losses.append(val_loss)      
    print('training loss:{} validation loss:{}'.format(train_loss,val_loss))       

HBox(children=(IntProgress(value=0, description='Downloading', max=361, style=ProgressStyle(description_width=…




HBox(children=(IntProgress(value=0, description='Downloading', max=440473133, style=ProgressStyle(description_…


training loss:39.35991589239764 validation loss:43.826463055253406
training loss:39.331333402254394 validation loss:43.890565284560154
training loss:39.36235692006367 validation loss:43.89929524733911
training loss:39.31968021992948 validation loss:43.91838281839052
training loss:39.271852087648185 validation loss:43.814784980863806
training loss:39.271614001349946 validation loss:43.899170158460315
training loss:39.24298890185159 validation loss:43.95398934998723
training loss:39.2680966125463 validation loss:43.94822983885601
training loss:39.230188481294974 validation loss:43.90578533481066
training loss:39.185819162640946 validation loss:43.94691698201447
training loss:39.15019943164688 validation loss:43.8743080613835
training loss:39.1979382640497 validation loss:43.86892881180962
training loss:39.21436924076175 validation loss:44.00979377068462
training loss:39.164274306146574 validation loss:43.99739631310058
training loss:39.14845025125 validation loss:44.01422585664488
train

In [0]:
import pickle
with open('/content/drive/My Drive/BERT-SEQ-Tagger/val_losses.pkl','wb') as f:
  pickle.dump(val_losses,f)

In [0]:
encoder=Encoder().to(device)
encoder.load_state_dict(torch.load('/content/drive/My Drive/BERT-SEQ-Tagger/encoder.pt'))
decoder=AttnDecoderRNN(vocab_size,hidden_size,len(tag_index)).to(device)
decoder.load_state_dict(torch.load('/content/drive/My Drive/BERT-SEQ-Tagger/decoder.pt'))
seq=torch.tensor(val.iloc[0]['encoding']).view(1,-1).to(device)
attn_mask=torch.tensor(val.iloc[0]['attn_mask']).view(1,-1).to(device)
labels=torch.tensor(val.iloc[0]['labels']).view(1,-1).to(device)

encoder_output=encoder(seq,attn_mask)
decoder_input = torch.tensor([encoder_output.shape[0]*[tag_index['<s>']]], device=device).view(-1,1)
decoder_hidden=decoder.initHidden()

labels= torch.cat((labels,torch.tensor([encoder_output.shape[0]*[tag_index['<s>']]], device=device).view(-1,1)),dim=1)
loss=0
for di in range(labels.shape[1]):
  decoder_output,decoder_hidden,_=decoder(decoder_input,decoder_hidden,encoder_output)
  # print(decoder_output.squeeze(0).shape)
  # loss += criterion(decoder_output.view(encoder_output.shape[0],-1), labels[:,di])
  _, top_idx = decoder_output.data.topk(1)
  decoder_input = top_idx.view(-1)
  print(decoder_input)        