In [1]:
!pip install transformers



In [2]:
import transformers

In [3]:
batch_size = 10
model_path = 'gpt2_epoch5.bin'
max_seq_len = 300
epochs = 5
data_path = '/content/eda-data.csv'
tokenizer = transformers.GPT2Tokenizer.from_pretrained('gpt2')
model = transformers.GPT2LMHeadModel.from_pretrained('gpt2')

Some weights of GPT2LMHeadModel were not initialized from the model checkpoint at gpt2 and are newly initialized: ['h.0.attn.masked_bias', 'h.1.attn.masked_bias', 'h.2.attn.masked_bias', 'h.3.attn.masked_bias', 'h.4.attn.masked_bias', 'h.5.attn.masked_bias', 'h.6.attn.masked_bias', 'h.7.attn.masked_bias', 'h.8.attn.masked_bias', 'h.9.attn.masked_bias', 'h.10.attn.masked_bias', 'h.11.attn.masked_bias', 'lm_head.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [4]:
import numpy as np
import re

def remove_source(text):
    cln_text = text
    if '(Source' in cln_text:
        cln_text,_,_ = cln_text.partition('(Source')
    elif '[Written ' in cln_text:
        cln_text,_,_ = cln_text.partition('[Written')
        
    return cln_text

def clean_synopsis(data):
    # removing hentai and kids tags
    data = data[(data.Hentai != 1) & (data.Kids != 1)]
    synopsis = data.synopsis
    synopsis = synopsis.apply(lambda x: str(x))

    # removing very small synopsis
    synopsis = synopsis.apply(lambda x: x if ((len(str(x).strip().split())<=300) and len(str(x).strip().split())>30  ) else -1)
    synopsis = synopsis[synopsis!=-1]
    
    # removing source text
    synopsis = synopsis.apply(lambda x: remove_source(x))
    
    # removing japanese characters
    synopsis = synopsis.apply(lambda x: re.sub("([^\x00-\x7F])+"," ",x))
    
    # remove symbols
    rx = re.compile('[&#/@`)(;<=\'"$%>]')
    synopsis = synopsis.apply(lambda x: rx.sub('',x))
    synopsis = synopsis.apply(lambda x: x.replace('>',""))
    synopsis = synopsis.apply(lambda x: x.replace('`',""))
    synopsis = synopsis.apply(lambda x: x.replace(')',""))
    synopsis = synopsis.apply(lambda x: x.replace('(',""))
    

    # removing adaptation animes (some relevant might get deleted but there aren`t a lot so we wont be affected as much)
    synopsis = synopsis[synopsis.apply(lambda x: 'adaptation' not in str(x).lower())]    
    synopsis = synopsis[synopsis.apply(lambda x: 'music video' not in str(x).lower())]
    synopsis = synopsis[synopsis.apply(lambda x: 'based on' not in str(x).lower())]
    synopsis = synopsis[synopsis.apply(lambda x: 'spin-off' not in str(x).lower())]
    
    return synopsis.reset_index(drop=True)


class AverageMeter:
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


In [5]:
import torch

class AnimeDataset():
    def __init__(self,data):
        self.eos_tok = '<|endoftext|>'
        synopsis = clean_synopsis(data)
        synopsis = synopsis.apply(lambda x: str(x) + self.eos_tok)
        self.synopsis = synopsis.tolist()
        self.pad_tok = tokenizer.encode(['<|pad|>'])
    def __getitem__(self,item):
        synopsis = self.synopsis[item]
        tokens = tokenizer.encode(synopsis)
        mask = [1]*len(tokens)
        
        max_len = max_seq_len
        if max_len>len(tokens):
            padding_len = max_len - len(tokens)
            tokens = tokens + self.pad_tok*padding_len
            mask = mask + [0]*padding_len
        else:
            tokens = tokens[:max_len]
            mask = mask[:max_len]
        
        if tokens[-1]!= tokenizer.encode(self.eos_tok)[0]:
            tokens[-1] = tokenizer.encode(self.eos_tok)[0]
        
        return {'ids':torch.tensor(tokens,dtype = torch.long),
                'mask': torch.tensor(mask,dtype = torch.long),
                'og_synpsis':synopsis}
    
        
        
     
    def __len__(self):
        return len(self.synopsis)



In [6]:
from tqdm import tqdm
import torch
import numpy as np


def train_fn(model,dataloader,optimizer,scheduler,device):
    model.train()
    tk0 = tqdm(dataloader, total = len(dataloader), leave = True, position = 0)
    train_loss = AverageMeter()
    losses = []
    for bi,d in enumerate(tk0):
            
        ids = d['ids'].to(device,dtype = torch.long)
        mask = d['mask'].to(device,dtype = torch.long)
        
        loss,out = model(input_ids = ids, labels = ids, attention_mask  = mask)[:2]
        
        train_loss.update(loss.item())    
        loss.backward()
        losses.append(loss.item())
        optimizer.step()
        scheduler.step()
        model.zero_grad()
        tk0.set_postfix(loss = train_loss.avg)
    return np.mean(losses)        

In [7]:
import pandas as pd
from transformers import AdamW 
from transformers import get_linear_schedule_with_warmup
import torch
def run():
    data = pd.read_csv(data_path)
    dataset = AnimeDataset(data = data)
    dataloader = torch.utils.data.DataLoader(dataset,batch_size=batch_size,shuffle=True)
    
    device = 'cuda'
    model.to(device)
    
    optimizer = AdamW(model.parameters(),lr = 0.0001,weight_decay = 0.003)    
    scheduler = get_linear_schedule_with_warmup(optimizer,
                                                num_warmup_steps=10,num_training_steps = int(len(data)/batch_size * epochs))
    
    best_loss = 111111
    for epoch in range(epochs):
        loss = train_fn(model,dataloader,optimizer,scheduler,device)
        if loss<best_loss:
            best_loss = loss
            torch.save(model.state_dict(),model_path)
        torch.cuda.empty_cache
    

In [8]:
run()

100%|██████████| 731/731 [18:22<00:00,  1.51s/it, loss=1.63]
100%|██████████| 731/731 [18:21<00:00,  1.51s/it, loss=1.46]
100%|██████████| 731/731 [18:21<00:00,  1.51s/it, loss=1.38]
100%|██████████| 731/731 [18:21<00:00,  1.51s/it, loss=1.31]
100%|██████████| 731/731 [18:19<00:00,  1.50s/it, loss=1.25]


In [12]:
def generate_text(input_text,device = 'cuda',max_len = 300):
  pad_tok = tokenizer.encode(['<|pad|>'])[0]
  model.load_state_dict(torch.load(model_path))
  model.to(device)
  model.eval()

  input_ids = tokenizer.encode(input_text)
  mask = [1]*len(input_ids)

  padding_len = max_seq_len - len(input_ids)
  
  input_ids = input_ids #+ pad_tok*padding_len
  mask = mask + [0]*padding_len

  ids = torch.tensor(input_ids,dtype = torch.long).to(device).unsqueeze(0)
  mask = torch.tensor(mask,dtype = torch.long).to(device).unsqueeze(0)
  
  #print(ids[0])
  sample_out = model.generate(ids, min_length = 30,max_length=max_len, pad_token_id=pad_tok,
                              top_p=0.85, early_stopping=True, do_sample=True, num_beams = 5, no_repeat_ngram_size = 2,num_return_sequences=1)
  
  print('Generated Text:\n\n',tokenizer.decode(sample_out[0],skip_special_tokens = True))


In [15]:
generate_text('When the night',device = 'cuda')

Generated Text:

 When the night before a school festival, a mysterious girl suddenly appears in front of the entire school. She claims to be from the future, and she wants to take over the world. To do this, she has to use her powers to transform into a magical girl.
