<a href="https://colab.research.google.com/github/AoShuang92/Doti/blob/main/SD_Transformer_TS.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
from torch.utils.data import Dataset
import torch.nn.functional as F
#system
from nltk.translate.bleu_score import sentence_bleu
import numpy as np
import os
import math
import pandas as pd
import re
import string
from nltk.stem import WordNetLemmatizer
import nltk
nltk.download('wordnet')
from nltk.corpus import wordnet
from collections import Counter
import json
from torchtext.vocab import Vectors, GloVe
from nltk.util import ngrams
from nltk.tokenize import word_tokenize
import torchtext
from torchtext.data.utils import get_tokenizer
from nltk.metrics import accuracy, precision, recall, f_measure
from nltk.translate.meteor_score import single_meteor_score
from torchtext.legacy import data
import warnings
warnings.filterwarnings('ignore')


def seed_everything(seed=20):
  #random.seed(seed)
  torch.manual_seed(seed)
  torch.cuda.manual_seed_all(seed)
  np.random.seed(seed)
  os.environ['PYTHONHASHSEED'] = str(seed)
  torch.backends.cudnn.deterministic = True
  torch.backends.cudnn.benchmark = False

[nltk_data] Downloading package wordnet to /root/nltk_data...
[nltk_data]   Unzipping corpora/wordnet.zip.


In [None]:
!pip install "nltk==3.4.5"
!pip install rouge-score

Collecting nltk==3.4.5
[?25l  Downloading https://files.pythonhosted.org/packages/f6/1d/d925cfb4f324ede997f6d47bea4d9babba51b49e87a767c170b77005889d/nltk-3.4.5.zip (1.5MB)
[K     |████████████████████████████████| 1.5MB 5.9MB/s 
Building wheels for collected packages: nltk
  Building wheel for nltk (setup.py) ... [?25l[?25hdone
  Created wheel for nltk: filename=nltk-3.4.5-cp37-none-any.whl size=1449907 sha256=7fb91471dd16c98ee62ba4817243aafc7e4c6846a9fb547e854cc45d1c372b74
  Stored in directory: /root/.cache/pip/wheels/96/86/f6/68ab24c23f207c0077381a5e3904b2815136b879538a24b483
Successfully built nltk
Installing collected packages: nltk
  Found existing installation: nltk 3.2.5
    Uninstalling nltk-3.2.5:
      Successfully uninstalled nltk-3.2.5
Successfully installed nltk-3.4.5
Collecting rouge-score
  Downloading https://files.pythonhosted.org/packages/1f/56/a81022436c08b9405a5247b71635394d44fe7e1dbedc4b28c740e09c2840/rouge_score-0.0.4-py2.py3-none-any.whl
Installing collected

In [None]:
max_length = 35
train_dir = "/content/drive/MyDrive/calibration_project/medical_dialogue_system/combined_qa_train_ID.csv"
test_dir = "/content/drive/MyDrive/calibration_project/medical_dialogue_system/combined_qa_test_200_ID.csv"
batch_size = 4


def remove_unnecessary(text):
    #remove_URL
    url = re.compile(r'https?://\S+|www\.\S+')
    text = url.sub('', text)

    #remove_html
    html = re.compile(r'<.*?>')
    text = html.sub('', text)

    #remove @
    text = re.sub('@[^\s]+','',text)

    #remove_emoji
    emoji_pattern = re.compile("["
                           u"\U0001F600-\U0001F64F"  # emoticons
                           u"\U0001F300-\U0001F5FF"  # symbols & pictographs
                           u"\U0001F680-\U0001F6FF"  # transport & map symbols
                           u"\U0001F1E0-\U0001F1FF"  # flags (iOS)
                           u"\U00002702-\U000027B0"
                           u"\U000024C2-\U0001F251"
                           "]+", flags=re.UNICODE)
    text = emoji_pattern.sub(r'', text)

    #Removes integers 
    text = ''.join([i for i in text if not i.isdigit()])         
    
    #remove_punct
    table = str.maketrans('', '', string.punctuation)
    text = text.translate(table)

    #Replaces contractions from a string to their equivalents 
    contraction_patterns = [(r'won\'t', 'will not'), (r'can\'t', 'cannot'), (r'i\'m', 'i am'), 
                            (r'ain\'t', 'is not'), (r'(\w+)\'ll', '\g<1> will'), (r'(\w+)n\'t', '\g<1> not'),
                            (r'(\w+)\'ve', '\g<1> have'), (r'(\w+)\'s', '\g<1> is'), (r'(\w+)\'re', '\g<1> are'),
                            (r'(\w+)\'d', '\g<1> would'), (r'&', 'and'), (r'dammit', 'damn it'), 
                            (r'dont', 'do not'), (r'wont', 'will not')]
    
    patterns = [(re.compile(regex), repl) for (regex, repl) in contraction_patterns]
    for (pattern, repl) in patterns:
        text, _= re.subn(pattern, repl, text)

    #lemmatize_sentence
    sentence_words = text.split(' ')
    new_sentence_words = list()
    
    for sentence_word in sentence_words:
        sentence_word = sentence_word.replace('#', '')
        new_sentence_word = WordNetLemmatizer().lemmatize(sentence_word.lower(), wordnet.VERB)
        new_sentence_words.append(new_sentence_word)
        
    new_sentence = ' '.join(new_sentence_words)
    new_sentence = new_sentence.strip()

    return new_sentence.lower()

def prepare_csv(train,test):
    # idx = np.arange(df_train.shape[0])    
    # np.random.shuffle(idx)
    # val_size = int(len(idx) * val_ratio)
    if not os.path.exists('cache'): # cache is tem memory file 
        os.makedirs('cache')
    
    train_temp = train[['Question', 'Answer']].to_csv(
        'cache/dataset_train.csv', index=True)
    
    test_temp = test[['Question', 'Answer']].to_csv(
        'cache/dataset_val.csv', index=True) 
    return  train_temp,  test_temp

def get_iterator(dataset, batch_size, train=True,
                 shuffle=True, repeat=False, device=None): 
    dataset_iter = data.Iterator(
        dataset, batch_size=batch_size, device=device,
        train=train, shuffle=shuffle, repeat=repeat,
        sort=False)  
    return dataset_iter

def get_dataset(fix_length=max_length, lower=False, vectors=None,train_dir = train_dir, test_dir = test_dir, batch_size=batch_size, device=None): 
    train = pd.read_csv(train_dir,error_bad_lines=False,sep=",")
    test =  pd.read_csv(test_dir,error_bad_lines=False, sep=",")
    train['Question'] = train['Question'].apply(lambda x: remove_unnecessary(x))
    train['Answer'] = train['Answer'].apply(lambda x: remove_unnecessary(x))
    test['Question'] = test['Question'].apply(lambda x: remove_unnecessary(x))
    
    test['Answer'] = test['Answer'].apply(lambda x: remove_unnecessary(x))
    train_temp,  test_temp = prepare_csv(train,test)
    if vectors is not None:
        lower=True

    TEXT = data.Field(tokenize=get_tokenizer("spacy"),init_token='<sos>',eos_token='<eos>',lower=True,batch_first=True, 
                      fix_length=fix_length)
    ID = data.Field(use_vocab=False, sequential=False, dtype=torch.float16)  
    train_temps = data.TabularDataset(
        path='/content/cache/dataset_train.csv', format='csv', skip_header=True,
        fields=[("ID",ID),('Question', TEXT), ('Answer', TEXT)]) 
    test_temps = data.TabularDataset(
        path='/content/cache/dataset_val.csv', format='csv', skip_header=True,
        fields=[("ID",ID),('Question', TEXT), ('Answer', TEXT)]) 

    TEXT.build_vocab(train_temps,test_temps)#, vectors=GloVe(name='6B', dim=300))
    ID.build_vocab(train_temps, test_temps)
    word_embeddings = TEXT.vocab.vectors
    vocab_size = len(TEXT.vocab)
    ntokens = len(TEXT.vocab.stoi)
    print("vocab_size_and_ntokens:",vocab_size,ntokens)
    train_loader = get_iterator(train_temps, batch_size=batch_size, 
                                train=True, shuffle=True,
                                repeat=False,device=None)
    test_loader = get_iterator(test_temps, batch_size=1, 
                            train=False, shuffle=False,
                            repeat=False, device=None)
    print('Train samples:%d'%(len(train_temps)), 'Valid samples:%d'%(len(test_temps)),'Train minibatch nb:%d'%(len(train_loader)),
            'Valid minibatch nb:%d'%(len(test_loader)))
    return vocab_size, word_embeddings, ntokens, train_loader, test_loader, TEXT

In [None]:
vocab_size, word_embeddings, ntokens, train_loader, test_loader, TEXT = get_dataset(fix_length=max_length,train_dir = train_dir, test_dir = test_dir, batch_size=batch_size)

vocab_size_and_ntokens: 1938 1938
Train samples:1001 Valid samples:200 Train minibatch nb:251 Valid minibatch nb:200


In [None]:
def create_masks(question, reply_input):
    
    def subsequent_mask(size):
        mask = torch.triu(torch.ones(size, size)).transpose(0, 1).type(dtype=torch.uint8)
        return mask.unsqueeze(0)
    
    question_mask = question!=0
    question_mask = question_mask.to(device)
    question_mask = question_mask.unsqueeze(1).unsqueeze(1)         # (batch_size, 1, 1, max_words)
     
    reply_input_mask = reply_input!=0
    reply_input_mask = reply_input_mask.unsqueeze(1)  # (batch_size, 1, max_words)
    reply_input_mask = reply_input_mask & subsequent_mask(reply_input.size(-1)).type_as(reply_input_mask.data) 
    reply_input_mask = reply_input_mask.unsqueeze(1) # (batch_size, 1, max_words, max_words)
    
    return question_mask, reply_input_mask

In [None]:
class Embeddings(nn.Module):
    """
    Implements embeddings of the words and adds their positional encodings. 
    """
    def __init__(self, vocab_size, d_model, max_len = max_length):
        super(Embeddings, self).__init__()
        self.d_model = d_model
        self.dropout = nn.Dropout(0.1)
        self.embed = nn.Embedding(vocab_size, d_model)
        self.pe = self.create_positinal_encoding(max_len, self.d_model)
        self.dropout = nn.Dropout(0.1)
        
    def create_positinal_encoding(self, max_len, d_model):
        pe = torch.zeros(max_len, d_model).to(device)
        for pos in range(max_len):   # for each position of the word
            for i in range(0, d_model, 2):   # for each dimension of the each position
                pe[pos, i] = math.sin(pos / (10000 ** ((2 * i)/d_model)))
                pe[pos, i + 1] = math.cos(pos / (10000 ** ((2 * (i + 1))/d_model)))
        pe = pe.unsqueeze(0)   # include the batch size
        return pe
        
    def forward(self, encoded_words):
        embedding = self.embed(encoded_words) * math.sqrt(self.d_model)
        embedding += self.pe[:, :embedding.size(1)]   # pe will automatically be expanded with the same batch size as encoded_words
        embedding = self.dropout(embedding)
        return embedding

class MultiHeadAttention(nn.Module):
    
    def __init__(self, heads, d_model):
        
        super(MultiHeadAttention, self).__init__()
        assert d_model % heads == 0
        self.d_k = d_model // heads
        self.heads = heads
        self.dropout = nn.Dropout(0.1)
        self.query = nn.Linear(d_model, d_model)
        self.key = nn.Linear(d_model, d_model)
        self.value = nn.Linear(d_model, d_model)
        self.concat = nn.Linear(d_model, d_model)
        
    def forward(self, query, key, value, mask):
        """
        query, key, value of shape: (batch_size, max_len, 512)
        mask of shape: (batch_size, 1, 1, max_words)
        """
        # (batch_size, max_len, 512)
        query = self.query(query)
        key = self.key(key)        
        value = self.value(value)   
        
        # (batch_size, max_len, 512) --> (batch_size, max_len, h, d_k) --> (batch_size, h, max_len, d_k)
        query = query.view(query.shape[0], -1, self.heads, self.d_k).permute(0, 2, 1, 3)   
        key = key.view(key.shape[0], -1, self.heads, self.d_k).permute(0, 2, 1, 3)  
        value = value.view(value.shape[0], -1, self.heads, self.d_k).permute(0, 2, 1, 3)  
        
        # (batch_size, h, max_len, d_k) matmul (batch_size, h, d_k, max_len) --> (batch_size, h, max_len, max_len)
        scores = torch.matmul(query, key.permute(0,1,3,2)) / math.sqrt(query.size(-1))
        #scores = torch.matmul(query, key.permute(2,1,0,0)) / math.sqrt(query.size(-1))
        scores = scores.masked_fill(mask == 0, -1e9)    # (batch_size, h, max_len, max_len)
        weights = F.softmax(scores, dim = -1)           # (batch_size, h, max_len, max_len)
        weights = self.dropout(weights)
        # (batch_size, h, max_len, max_len) matmul (batch_size, h, max_len, d_k) --> (batch_size, h, max_len, d_k)
        context = torch.matmul(weights, value)
        # (batch_size, h, max_len, d_k) --> (batch_size, max_len, h, d_k) --> (batch_size, max_len, h * d_k)
        context = context.permute(0,2,1,3).contiguous().view(context.shape[0], -1, self.heads * self.d_k)
        # (batch_size, max_len, h * d_k)
        interacted = self.concat(context)
        return interacted

class FeedForward(nn.Module):

    def __init__(self, d_model, middle_dim = 2048):
        super(FeedForward, self).__init__()
        
        self.fc1 = nn.Linear(d_model, middle_dim)
        self.fc2 = nn.Linear(middle_dim, d_model)
        self.dropout = nn.Dropout(0.1)

    def forward(self, x):
        out = F.relu(self.fc1(x))
        out = self.fc2(self.dropout(out))
        return out

class EncoderLayer(nn.Module):

    def __init__(self, d_model, heads):
        super(EncoderLayer, self).__init__()
        self.layernorm = nn.LayerNorm(d_model)
        self.self_multihead = MultiHeadAttention(heads, d_model)
        self.feed_forward = FeedForward(d_model)
        self.dropout = nn.Dropout(0.1)

    def forward(self, embeddings, mask):
        interacted = self.dropout(self.self_multihead(embeddings, embeddings, embeddings, mask))
        interacted = self.layernorm(interacted + embeddings)
        feed_forward_out = self.dropout(self.feed_forward(interacted))
        encoded = self.layernorm(feed_forward_out + interacted)
        return encoded

class DecoderLayer(nn.Module):
    
    def __init__(self, d_model, heads):
        super(DecoderLayer, self).__init__()
        self.layernorm = nn.LayerNorm(d_model)
        self.self_multihead = MultiHeadAttention(heads, d_model)
        self.src_multihead = MultiHeadAttention(heads, d_model)
        self.feed_forward = FeedForward(d_model)
        self.dropout = nn.Dropout(0.1)
        
    def forward(self, embeddings, encoded, src_mask, target_mask):
        query = self.dropout(self.self_multihead(embeddings, embeddings, embeddings, target_mask))
        query = self.layernorm(query + embeddings)
        interacted = self.dropout(self.src_multihead(query, encoded, encoded, src_mask))
        interacted = self.layernorm(interacted + query)
        feed_forward_out = self.dropout(self.feed_forward(interacted))
        decoded = self.layernorm(feed_forward_out + interacted)
        return decoded


class Transformer(nn.Module):    
    def __init__(self, d_model, heads, num_layers, ntokens):
        super(Transformer, self).__init__()
        
        self.d_model = d_model
        self.vocab_size = ntokens
        self.embed = Embeddings(self.vocab_size, d_model)#max_len
        self.embed_dec = Embeddings(self.vocab_size, d_model,max_length)
        self.encoder = nn.ModuleList([EncoderLayer(d_model, heads) for _ in range(num_layers)])
        self.decoder = nn.ModuleList([DecoderLayer(d_model, heads) for _ in range(num_layers)])
        self.logit = nn.Linear(d_model, self.vocab_size)   
        
    def encode(self, src_words, src_mask):
        src_embeddings = self.embed(src_words)
        for layer in self.encoder:
            src_embeddings = layer(src_embeddings, src_mask)
        return src_embeddings
    
    def decode(self, target_words, target_mask, src_embeddings, src_mask):
        tgt_embeddings = self.embed_dec(target_words)
        for layer in self.decoder:
            tgt_embeddings = layer(tgt_embeddings, src_embeddings, src_mask, target_mask)
        return tgt_embeddings
        
    def forward(self, src_words, src_mask, target_words, target_mask):
        encoded = self.encode(src_words, src_mask)
        decoded = self.decode(target_words, target_mask, encoded, src_mask)
        out = F.log_softmax(self.logit(decoded), dim = 2)
        return out

class AdamWarmup:
    
    def __init__(self, model_size, warmup_steps, optimizer):
        
        self.model_size = model_size
        self.warmup_steps = warmup_steps
        self.optimizer = optimizer
        self.current_step = 0
        self.lr = 0
        
    def get_lr(self):
        return self.model_size ** (-0.5) * min(self.current_step ** (-0.5), self.current_step * self.warmup_steps ** (-1.5))
        
    def step(self):
        # Increment the number of steps each time we call the step function
        self.current_step += 1
        lr = self.get_lr()
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr
        # update the learning rate
        self.lr = lr
        self.optimizer.step()


In [None]:
def train(train_loader, transformer, criterion, epoch):    
    transformer.train()
    sum_loss = 0
    count = 0
    
    for i, pair in enumerate(train_loader):

        question = pair.Question.to(device)
        reply = pair.Answer.to(device)
        reply_input = reply[:, :-1]
        reply_target = reply[:, 1:]

        # Create mask and add dimensions
        question_mask, reply_input_mask = create_masks(question, reply_input)
        out = transformer(question, question_mask, reply_input, reply_input_mask)
        reply_target = reply_target.reshape(-1)
        loss = criterion(out.view(-1, ntokens), reply_target)
        
        # Backprop
        transformer_optimizer.optimizer.zero_grad()
        loss.backward()
        transformer_optimizer.step()
        

def valid (test_loader,transformer): 
    all_blue = []
    word_map = TEXT.vocab.stoi
    rev_word_map = TEXT.vocab.itos

    transformer.eval()
    for i, pair in enumerate(test_loader):
    
        question = pair.Question.to(device)
        reply = pair.Answer.to(device)
        reply_input = reply[:, :-1]
        reply_target = reply[:, 1:]
        question_mask, reply_input_mask = create_masks(question, reply_input)
        out = transformer(question, question_mask, reply_input, reply_input_mask)
        _, next = torch.max(out, dim = 2)# 2x51
        for idx in range(next.shape[0]):
            pred_sentence= prediction_ids2sentence(next[idx]).split()
            gt=prediction_ids2sentence(reply_target[idx]).split()
            BLEU_1 = sentence_bleu(([gt]), pred_sentence, weights=(1, 0, 0, 0))
            all_blue.append(BLEU_1)
    
    return np.mean(all_blue)

def evaluate(transformer, question, question_mask, max_len):
    """
    Performs Greedy Decoding with a batch size of 1
    """
    #rev_word_map = {v: k for k, v in word_embeddings.items()}
    word_map = TEXT.vocab.stoi
    rev_word_map = TEXT.vocab.itos
    transformer.eval()
    start_token = word_map['<sos>']
    encoded = transformer.encode(question, question_mask)
    words = torch.LongTensor([[start_token]]).to(device)
    next_word = -22
    while next_word != word_map['<eos>']:
    #for step in range(max_len - 1):
        size = words.shape[1]
        target_mask = torch.triu(torch.ones(size, size)).transpose(0, 1).type(dtype=torch.uint8)
        target_mask = target_mask.to(device).unsqueeze(0).unsqueeze(0)
        decoded = transformer.decode(words, target_mask, encoded, question_mask)
        predictions = transformer.logit(decoded[:, -1])
        _, next_word = torch.max(predictions, dim = 1)
        next_word = next_word.item()
        if next_word == word_map['<eos>'] or words.shape[1]==(max_len+1):
            break
        words = torch.cat([words, torch.LongTensor([[next_word]]).to(device)], dim = 1)   # (1,step+2)
        
    # Construct Sentence
    if words.dim() == 2:
        words = words.squeeze(0)
        words = words.tolist()
        
    sen_idx = [w for w in words if w not in {word_map['<sos>']}]
    sentence = ' '.join([rev_word_map[sen_idx[k]] for k in range(len(sen_idx))])
    

    return sentence

def prediction_ids2sentence(pred_ids):
    #rev_word_map = {v: k for k, v in word_embeddings.items()}
    word_map = TEXT.vocab.stoi
    rev_word_map = TEXT.vocab.itos
    #_, next = torch.max(out, dim = 2)
    sen_idx = []
    for w in pred_ids:
        if w == word_map['<eos>']:
            break
        sen_idx.append(w)
    #print(sen_idx)
    sentence = ' '.join([rev_word_map[int(sen_idx[k])] for k in range(len(sen_idx))])
    return sentence

from nltk.metrics import accuracy, precision, recall, f_measure

def evaluate_matrics(transformer,test_loader):
    sum_loss = 0
    all_blue1 = []
    all_meteor = []
    all_rouge = []
    word_map = TEXT.vocab.stoi
    rev_word_map = TEXT.vocab.itos

    transformer.eval()
    for i, pair in enumerate(test_loader):
        
        question = pair.Question.to(device)
        reply = pair.Answer.to(device)
        
        reply_input = reply[:, :-1]
        reply_target = reply[:, 1:]
        question_mask, reply_input_mask = create_masks(question, reply_input)
        out = transformer(question, question_mask, reply_input, reply_input_mask)
        reply_target_mask = reply_target.reshape(-1)
        loss = criterion(out.view(-1, ntokens), reply_target_mask)
        sum_loss += loss.item()
        _, next = torch.max(out, dim = 2)# 2x51
        
        for idx in range(next.shape[0]):
            pred_sentence= prediction_ids2sentence(next[idx]).split()
            gt=prediction_ids2sentence(reply_target[idx]).split()
            BLEU_1 = sentence_bleu(([gt]), pred_sentence, weights=(1, 0, 0, 0))
            reference_set = set(gt)
            test_set = set(pred_sentence)
            meteor = single_meteor_score( str(gt), str(pred_sentence))
            scorer = rouge_scorer.RougeScorer(['rougeL'], use_stemmer=True)
            scoress = scorer.score(str(pred_sentence), str(gt))

            all_blue1.append(BLEU_1)
            all_meteor.append(meteor)
            all_rouge.append(scoress['rougeL'][2])
           
    bleu_score_1 = np.mean(all_blue1)
    met = np.mean(all_meteor)
    ppl = math.exp(sum_loss/i)
    rouge = np.mean(all_rouge)

    print("BLEU_SCORE1:",bleu_score_1, "Rouge:",rouge, "Meteor:",met,"PPL:",ppl)

In [None]:
import torch
from torch import nn, optim
from torch.nn import functional as F

class ModelWithTemperature(nn.Module):
    def __init__(self, model):
        super(ModelWithTemperature, self).__init__()
        self.model = model
        self.temperature = nn.Parameter(torch.ones(1) * 1.5)

    def forward(self, inputs, input_mask, targets, targets_mask):
        logits = self.model(inputs, input_mask, targets, targets_mask)
        return self.temperature_scale(logits)

    def temperature_scale(self, logits):
        # Expand temperature to match the size of logits
        temperature = self.temperature.unsqueeze(1).expand(logits.size())
        return logits / temperature

    # This function probably should live outside of this class, but whatever
    def set_temperature(self, valid_loader):
        self.cuda()
        ece_criterion = _ECELoss().cuda()
        nll_criterion = nn.CrossEntropyLoss().cuda()

        # First: collect all the logits and labels for the validation set
        logits_list = []
        labels_list = []
        with torch.no_grad():
            for i, pair in enumerate(valid_loader):
    
                input = pair.Question.cuda()
                label = pair.Answer.cuda()
            #for input, label in valid_loader:
                input = input.cuda()
                label = label.cuda()
                label = label[:, 1:]
                input_mask, label_mask = create_masks(input, label)
                logits = self.model(input, input_mask, label, label_mask)
                logits_list.append(logits)
                labels_list.append(label)
            logits = torch.cat(logits_list).cuda()
            labels = torch.cat(labels_list).cuda()
            

        # Next: optimize the temperature w.r.t. NLL
        init_temp = self.temperature.clone()
        optimizer = optim.LBFGS([self.temperature], lr=0.01, max_iter=50)

        def eval():
            labels_loss = labels.reshape(-1)
            loss = nll_criterion(self.temperature_scale(logits.view(-1, ntokens)), labels_loss)
            loss.backward()
            return loss
        optimizer.step(eval)

        # CalculateECE after temperature scaling
        labels_loss = labels.reshape(-1)
        after_temperature_ece = ece_criterion(self.temperature_scale(logits.view(-1,ntokens )), labels_loss).item()
        print('Initial temperature: %.3f, Optimal temperature: %.3f' % (init_temp, self.temperature.item()))
        #print('Initial temperature: %.3f, Optimal temperature: %.3f' % (init_temp, after_temperature_ece))
        return self
        
class _ECELoss(nn.Module):
    def __init__(self, n_bins=15):
        """
        n_bins (int): number of confidence interval bins
        """
        super(_ECELoss, self).__init__()
        bin_boundaries = torch.linspace(0, 1, n_bins + 1)
        self.bin_lowers = bin_boundaries[:-1]
        self.bin_uppers = bin_boundaries[1:]

    def forward(self, logits, labels):
        softmaxes = F.softmax(logits, dim=1)
        confidences, predictions = torch.max(softmaxes, 1)
        accuracies = predictions.eq(labels)
        ece = torch.zeros(1, device=logits.device)
        for bin_lower, bin_upper in zip(self.bin_lowers, self.bin_uppers):
            # Calculated |confidence - accuracy| in each bin
            in_bin = confidences.gt(bin_lower.item()) * confidences.le(bin_upper.item())
            prop_in_bin = in_bin.float().mean()
            if prop_in_bin.item() > 0:
                accuracy_in_bin = accuracies[in_bin].float().mean()
                avg_confidence_in_bin = confidences[in_bin].mean()
                ece += torch.abs(avg_confidence_in_bin - accuracy_in_bin) * prop_in_bin

        return ece

def evaluation(model, test_loader):
    
    logits_list = []
    labels_list = []
    all_blue = []
    
    model.eval()
    with torch.no_grad():
        
        for i, pair in enumerate(test_loader):
    
            inputs = pair.Question
            targets = pair.Answer
            inputs, targets = inputs.to(device), targets.to(device)
            targets = targets[:, 1:]
            input_mask, targets_mask = create_masks(inputs, targets)
            outputs = model(inputs, input_mask, targets, targets_mask)
            
            logits_list.append(outputs)
            labels_list.append(targets)
            
            _, predicted = outputs.max(2)
        
            for idx in range(predicted.shape[0]):
                
                pred_sentence= prediction_ids2sentence(predicted[idx]).split()
                gt=prediction_ids2sentence(targets[idx]).split()
                BLEU_1 = sentence_bleu(([gt]), pred_sentence, weights=(1, 0, 0, 0))
                reference_set = set(gt)
                test_set = set(pred_sentence)
                all_blue.append(BLEU_1)

    logits_all = torch.cat(logits_list).cuda()
    labels_all = torch.cat(labels_list).cuda()
    
    return np.mean(all_blue),logits_all, labels_all

getting optimal T

In [None]:
seed_everything()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
criterion = nn.CrossEntropyLoss()
d_model = 512
heads = 8
num_layers = 3
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Transformer(d_model = d_model, heads = heads, num_layers = num_layers, ntokens = ntokens)
model = model.to(device)
adam_optimizer = torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9)
transformer_optimizer = AdamWarmup(model_size = d_model, warmup_steps = 4000, optimizer = adam_optimizer)

model.load_state_dict(torch.load('/content/drive/MyDrive/calibration_project/medical_dialogue_system/best_models_baseline.pth'))
ece_criterion = _ECELoss().to(device)

bleu, logits_all, labels_all = evaluation(model, test_loader)
logits_all = logits_all.view(-1,ntokens)
labels_all = labels_all.view(-1)
temperature_ece = ece_criterion(logits_all, labels_all).item()
print('Before TS- bleu:%.3f, bef ece:%.5f'%(bleu,temperature_ece))

model_ts = ModelWithTemperature(model)
model_ts.set_temperature(test_loader)
bleu, logits_all, labels_all = evaluation(model_ts, test_loader)
logits_all = logits_all.view(-1,ntokens)
labels_all = labels_all.view(-1)
temperature_ece = ece_criterion(logits_all, labels_all).item()
print('After TS- bleu:%.3f,aft ece:%.5f'%(bleu,temperature_ece))

Before TS- bleu:0.405, bef ece:0.38368
Initial temperature: 1.500, Optimal temperature: 5.025
After TS- bleu:0.405,aft ece:0.33059


Training with SD, T = 1.5

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class SoftTarget(nn.Module):
	'''
	Distilling the Knowledge in a Neural Network
	https://arxiv.org/pdf/1503.02531.pdf
	'''
	def __init__(self, T):
		super(SoftTarget, self).__init__()
		self.T = T

	def forward(self, out_s, out_t):
		loss = F.kl_div(F.log_softmax(out_s/self.T, dim=1),
						F.softmax(out_t/self.T, dim=1),
						reduction='batchmean') * self.T * self.T

		return loss

def train_sd(train_loader, transformer_t, transformer_s, criterion, criterionKD, transformer_optimizer, epoch):    
    transformer_s.train()
    transformer_t.eval()
    sum_loss = 0
    count = 0
    
    for i, pair in enumerate(train_loader): 
        question = pair.Question.to(device)
        reply = pair.Answer.to(device)
        reply_input = reply[:, :-1]
        reply_target = reply[:, 1:]

        # Create mask and add dimensions
        question_mask, reply_input_mask = create_masks(question, reply_input)
        out_s = transformer_s(question, question_mask, reply_input, reply_input_mask)
        with torch.no_grad():
            out_t = transformer_t(question, question_mask, reply_input, reply_input_mask)
        reply_target = reply_target.reshape(-1)
        loss_cls = criterion(out_s.view(-1, ntokens), reply_target)
        kd_loss = criterionKD(out_s.view(-1, ntokens), out_t.detach().view(-1, ntokens))
        loss = loss_cls + kd_loss
        # Backprop
        transformer_optimizer.optimizer.zero_grad()
        loss.backward()
        transformer_optimizer.step()

seed_everything()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
criterion = nn.CrossEntropyLoss()

d_model = 512
heads = 8
num_layers = 3
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_t = Transformer(d_model = d_model, heads = heads, num_layers = num_layers, ntokens = ntokens).to(device)
model_s = Transformer(d_model = d_model, heads = heads, num_layers = num_layers, ntokens = ntokens).to(device)
adam_optimizer = torch.optim.Adam(model_s.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9)
transformer_optimizer = AdamWarmup(model_size = d_model, warmup_steps = 4000, optimizer = adam_optimizer)

model_t.load_state_dict(torch.load('/content/drive/MyDrive/calibration_project/medical_dialogue_system/best_models_baseline.pth'))

T = 1.5
criterionKD = SoftTarget(T)
best_blue = 0
best_epoch = 0
for epoch in range(100):
    
    train_sd(train_loader, model_t, model_s, criterion, criterionKD, transformer_optimizer, epoch)
    blue_score = valid (test_loader, model_s)
    if blue_score > best_blue:
        best_blue = blue_score
        best_epoch = epoch
        torch.save(model_s.state_dict(), '/content/drive/MyDrive/calibration_project/medical_dialogue_system/best_chatbot_models_sd_1.5t.pth')
    print('cur epoch:%d, cur blue:%.5f, best epoch:%d, best blue:%.5f'%(epoch,blue_score, best_epoch, best_blue))



cur epoch:0, cur blue:0.00153, best epoch:0, best blue:0.00153
cur epoch:1, cur blue:0.09313, best epoch:1, best blue:0.09313
cur epoch:2, cur blue:0.15070, best epoch:2, best blue:0.15070
cur epoch:3, cur blue:0.17288, best epoch:3, best blue:0.17288
cur epoch:4, cur blue:0.21768, best epoch:4, best blue:0.21768
cur epoch:5, cur blue:0.24554, best epoch:5, best blue:0.24554
cur epoch:6, cur blue:0.26875, best epoch:6, best blue:0.26875
cur epoch:7, cur blue:0.32112, best epoch:7, best blue:0.32112
cur epoch:8, cur blue:0.33757, best epoch:8, best blue:0.33757
cur epoch:9, cur blue:0.34034, best epoch:9, best blue:0.34034
cur epoch:10, cur blue:0.35039, best epoch:10, best blue:0.35039
cur epoch:11, cur blue:0.35494, best epoch:11, best blue:0.35494
cur epoch:12, cur blue:0.33527, best epoch:11, best blue:0.35494
cur epoch:13, cur blue:0.35988, best epoch:13, best blue:0.35988
cur epoch:14, cur blue:0.34959, best epoch:13, best blue:0.35988
cur epoch:15, cur blue:0.32483, best epoch:13

In [None]:
 from rouge_score import rouge_scorer
evaluate_matrics(model_s,test_loader)

BLEU_SCORE1: 0.43820982960671756 Rouge: 0.45582540667987315 Meteor: 0.4231837980578683 PPL: 10.35602329891361


Training with SD, T=2

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class SoftTarget(nn.Module):
	'''
	Distilling the Knowledge in a Neural Network
	https://arxiv.org/pdf/1503.02531.pdf
	'''
	def __init__(self, T):
		super(SoftTarget, self).__init__()
		self.T = T

	def forward(self, out_s, out_t):
		loss = F.kl_div(F.log_softmax(out_s/self.T, dim=1),
						F.softmax(out_t/self.T, dim=1),
						reduction='batchmean') * self.T * self.T

		return loss

def train_sd(train_loader, transformer_t, transformer_s, criterion, criterionKD, transformer_optimizer, epoch):    
    transformer_s.train()
    transformer_t.eval()
    sum_loss = 0
    count = 0
    
    for i, pair in enumerate(train_loader): 
        question = pair.Question.to(device)
        reply = pair.Answer.to(device)
        reply_input = reply[:, :-1]
        reply_target = reply[:, 1:]

        # Create mask and add dimensions
        question_mask, reply_input_mask = create_masks(question, reply_input)
        out_s = transformer_s(question, question_mask, reply_input, reply_input_mask)
        with torch.no_grad():
            out_t = transformer_t(question, question_mask, reply_input, reply_input_mask)
        reply_target = reply_target.reshape(-1)
        loss_cls = criterion(out_s.view(-1, ntokens), reply_target)
        kd_loss = criterionKD(out_s.view(-1, ntokens), out_t.detach().view(-1, ntokens))
        loss = loss_cls + kd_loss
        # Backprop
        transformer_optimizer.optimizer.zero_grad()
        loss.backward()
        transformer_optimizer.step()

seed_everything()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
criterion = nn.CrossEntropyLoss()

d_model = 512
heads = 8
num_layers = 3
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_t2 = Transformer(d_model = d_model, heads = heads, num_layers = num_layers, ntokens = ntokens).to(device)
model_s2 = Transformer(d_model = d_model, heads = heads, num_layers = num_layers, ntokens = ntokens).to(device)
adam_optimizer = torch.optim.Adam(model_s2.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9)
transformer_optimizer = AdamWarmup(model_size = d_model, warmup_steps = 4000, optimizer = adam_optimizer)

model_t2.load_state_dict(torch.load('/content/drive/MyDrive/calibration_project/medical_dialogue_system/best_models_baseline.pth'))

T = 2
criterionKD = SoftTarget(T)
best_blue = 0
best_epoch = 0
for epoch in range(100):
    
    train_sd(train_loader, model_t2, model_s2, criterion, criterionKD, transformer_optimizer, epoch)
    blue_score = valid (test_loader, model_s2)
    if blue_score > best_blue:
        best_blue = blue_score
        best_epoch = epoch
        torch.save(model_s2.state_dict(), '/content/drive/MyDrive/calibration_project/medical_dialogue_system/best_chatbot_models_sd_2t.pth')
    print('cur epoch:%d, cur blue:%.5f, best epoch:%d, best blue:%.5f'%(epoch,blue_score, best_epoch, best_blue))

cur epoch:0, cur blue:0.00002, best epoch:0, best blue:0.00002
cur epoch:1, cur blue:0.11241, best epoch:1, best blue:0.11241
cur epoch:2, cur blue:0.14806, best epoch:2, best blue:0.14806
cur epoch:3, cur blue:0.16120, best epoch:3, best blue:0.16120
cur epoch:4, cur blue:0.18181, best epoch:4, best blue:0.18181
cur epoch:5, cur blue:0.20640, best epoch:5, best blue:0.20640
cur epoch:6, cur blue:0.27457, best epoch:6, best blue:0.27457
cur epoch:7, cur blue:0.28708, best epoch:7, best blue:0.28708
cur epoch:8, cur blue:0.31303, best epoch:8, best blue:0.31303
cur epoch:9, cur blue:0.34107, best epoch:9, best blue:0.34107
cur epoch:10, cur blue:0.31539, best epoch:9, best blue:0.34107
cur epoch:11, cur blue:0.31423, best epoch:9, best blue:0.34107
cur epoch:12, cur blue:0.33945, best epoch:9, best blue:0.34107
cur epoch:13, cur blue:0.34015, best epoch:9, best blue:0.34107
cur epoch:14, cur blue:0.35060, best epoch:14, best blue:0.35060
cur epoch:15, cur blue:0.33898, best epoch:14, be

In [None]:
from rouge_score import rouge_scorer
model_t2 = Transformer(d_model = d_model, heads = heads, num_layers = num_layers, ntokens = ntokens).to(device)
model_s2 = Transformer(d_model = d_model, heads = heads, num_layers = num_layers, ntokens = ntokens).to(device)
model_s2.load_state_dict(torch.load('/content/drive/MyDrive/calibration_project/medical_dialogue_system/best_chatbot_models_sd_2t.pth'))
evaluate_matrics(model_s2,test_loader)

BLEU_SCORE1: 0.4477079312147275 Rouge: 0.46757194816556946 Meteor: 0.42746513264813046 PPL: 10.15113923168798


Training with SD, T=3

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class SoftTarget(nn.Module):
	'''
	Distilling the Knowledge in a Neural Network
	https://arxiv.org/pdf/1503.02531.pdf
	'''
	def __init__(self, T):
		super(SoftTarget, self).__init__()
		self.T = T

	def forward(self, out_s, out_t):
		loss = F.kl_div(F.log_softmax(out_s/self.T, dim=1),
						F.softmax(out_t/self.T, dim=1),
						reduction='batchmean') * self.T * self.T

		return loss

def train_sd(train_loader, transformer_t, transformer_s, criterion, criterionKD, transformer_optimizer, epoch):    
    transformer_s.train()
    transformer_t.eval()
    sum_loss = 0
    count = 0
    
    for i, pair in enumerate(train_loader): 
        question = pair.Question.to(device)
        reply = pair.Answer.to(device)
        reply_input = reply[:, :-1]
        reply_target = reply[:, 1:]

        # Create mask and add dimensions
        question_mask, reply_input_mask = create_masks(question, reply_input)
        out_s = transformer_s(question, question_mask, reply_input, reply_input_mask)
        with torch.no_grad():
            out_t = transformer_t(question, question_mask, reply_input, reply_input_mask)
        reply_target = reply_target.reshape(-1)
        loss_cls = criterion(out_s.view(-1, ntokens), reply_target)
        kd_loss = criterionKD(out_s.view(-1, ntokens), out_t.detach().view(-1, ntokens))
        loss = loss_cls + kd_loss
        # Backprop
        transformer_optimizer.optimizer.zero_grad()
        loss.backward()
        transformer_optimizer.step()

seed_everything()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
criterion = nn.CrossEntropyLoss()

d_model = 512
heads = 8
num_layers = 3
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_t3 = Transformer(d_model = d_model, heads = heads, num_layers = num_layers, ntokens = ntokens).to(device)
model_s3 = Transformer(d_model = d_model, heads = heads, num_layers = num_layers, ntokens = ntokens).to(device)
adam_optimizer = torch.optim.Adam(model_s3.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9)
transformer_optimizer = AdamWarmup(model_size = d_model, warmup_steps = 4000, optimizer = adam_optimizer)

model_t3.load_state_dict(torch.load('/content/drive/MyDrive/calibration_project/medical_dialogue_system/best_models_baseline.pth'))

T = 3
criterionKD = SoftTarget(T)
best_blue = 0
best_epoch = 0
for epoch in range(100):
    
    train_sd(train_loader, model_t3, model_s3, criterion, criterionKD, transformer_optimizer, epoch)
    blue_score = valid (test_loader, model_s3)
    if blue_score > best_blue:
        best_blue = blue_score
        best_epoch = epoch
        torch.save(model_s3.state_dict(), '/content/drive/MyDrive/calibration_project/medical_dialogue_system/best_chatbot_models_sd_3t.pth')
    print('cur epoch:%d, cur blue:%.5f, best epoch:%d, best blue:%.5f'%(epoch,blue_score, best_epoch, best_blue))

cur epoch:0, cur blue:0.00000, best epoch:0, best blue:0.00000
cur epoch:1, cur blue:0.11092, best epoch:1, best blue:0.11092
cur epoch:2, cur blue:0.13236, best epoch:2, best blue:0.13236
cur epoch:3, cur blue:0.15837, best epoch:3, best blue:0.15837
cur epoch:4, cur blue:0.18281, best epoch:4, best blue:0.18281
cur epoch:5, cur blue:0.22478, best epoch:5, best blue:0.22478
cur epoch:6, cur blue:0.25175, best epoch:6, best blue:0.25175
cur epoch:7, cur blue:0.26997, best epoch:7, best blue:0.26997
cur epoch:8, cur blue:0.29078, best epoch:8, best blue:0.29078
cur epoch:9, cur blue:0.32251, best epoch:9, best blue:0.32251
cur epoch:10, cur blue:0.31240, best epoch:9, best blue:0.32251
cur epoch:11, cur blue:0.33043, best epoch:11, best blue:0.33043
cur epoch:12, cur blue:0.32563, best epoch:11, best blue:0.33043
cur epoch:13, cur blue:0.34783, best epoch:13, best blue:0.34783
cur epoch:14, cur blue:0.31138, best epoch:13, best blue:0.34783
cur epoch:15, cur blue:0.32509, best epoch:13,

In [None]:
from rouge_score import rouge_scorer
model_t3 = Transformer(d_model = d_model, heads = heads, num_layers = num_layers, ntokens = ntokens).to(device)
model_s3 = Transformer(d_model = d_model, heads = heads, num_layers = num_layers, ntokens = ntokens).to(device)
model_s3.load_state_dict(torch.load('/content/drive/MyDrive/calibration_project/medical_dialogue_system/best_chatbot_models_sd_3t.pth'))
evaluate_matrics(model_s3,test_loader)

BLEU_SCORE1: 0.4465045968261427 Rouge: 0.46306013556599646 Meteor: 0.432026245843939 PPL: 10.197486452813704


Traning with SD, T=4

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class SoftTarget(nn.Module):
	'''
	Distilling the Knowledge in a Neural Network
	https://arxiv.org/pdf/1503.02531.pdf
	'''
	def __init__(self, T):
		super(SoftTarget, self).__init__()
		self.T = T

	def forward(self, out_s, out_t):
		loss = F.kl_div(F.log_softmax(out_s/self.T, dim=1),
						F.softmax(out_t/self.T, dim=1),
						reduction='batchmean') * self.T * self.T

		return loss

def train_sd(train_loader, transformer_t, transformer_s, criterion, criterionKD, transformer_optimizer, epoch):    
    transformer_s.train()
    transformer_t.eval()
    sum_loss = 0
    count = 0
    
    for i, pair in enumerate(train_loader): 
        question = pair.Question.to(device)
        reply = pair.Answer.to(device)
        reply_input = reply[:, :-1]
        reply_target = reply[:, 1:]

        # Create mask and add dimensions
        question_mask, reply_input_mask = create_masks(question, reply_input)
        out_s = transformer_s(question, question_mask, reply_input, reply_input_mask)
        with torch.no_grad():
            out_t = transformer_t(question, question_mask, reply_input, reply_input_mask)
        reply_target = reply_target.reshape(-1)
        loss_cls = criterion(out_s.view(-1, ntokens), reply_target)
        kd_loss = criterionKD(out_s.view(-1, ntokens), out_t.detach().view(-1, ntokens))
        loss = loss_cls + kd_loss
        # Backprop
        transformer_optimizer.optimizer.zero_grad()
        loss.backward()
        transformer_optimizer.step()

seed_everything()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
criterion = nn.CrossEntropyLoss()

d_model = 512
heads = 8
num_layers = 3
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_t4 = Transformer(d_model = d_model, heads = heads, num_layers = num_layers, ntokens = ntokens).to(device)
model_s4 = Transformer(d_model = d_model, heads = heads, num_layers = num_layers, ntokens = ntokens).to(device)
adam_optimizer = torch.optim.Adam(model_s4.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9)
transformer_optimizer = AdamWarmup(model_size = d_model, warmup_steps = 4000, optimizer = adam_optimizer)

model_t4.load_state_dict(torch.load('/content/drive/MyDrive/calibration_project/medical_dialogue_system/best_models_baseline.pth'))

T = 4
criterionKD = SoftTarget(T)
best_blue = 0
best_epoch = 0
for epoch in range(100):
    
    train_sd(train_loader, model_t4, model_s4, criterion, criterionKD, transformer_optimizer, epoch)
    blue_score = valid (test_loader, model_s4)
    if blue_score > best_blue:
        best_blue = blue_score
        best_epoch = epoch
        torch.save(model_s4.state_dict(), '/content/drive/MyDrive/calibration_project/medical_dialogue_system/best_chatbot_models_sd_4t.pth')
    print('cur epoch:%d, cur blue:%.5f, best epoch:%d, best blue:%.5f'%(epoch,blue_score, best_epoch, best_blue))

cur epoch:0, cur blue:0.00000, best epoch:0, best blue:0.00000
cur epoch:1, cur blue:0.09470, best epoch:1, best blue:0.09470
cur epoch:2, cur blue:0.13577, best epoch:2, best blue:0.13577
cur epoch:3, cur blue:0.15288, best epoch:3, best blue:0.15288
cur epoch:4, cur blue:0.17612, best epoch:4, best blue:0.17612
cur epoch:5, cur blue:0.22725, best epoch:5, best blue:0.22725
cur epoch:6, cur blue:0.24488, best epoch:6, best blue:0.24488
cur epoch:7, cur blue:0.27504, best epoch:7, best blue:0.27504
cur epoch:8, cur blue:0.29651, best epoch:8, best blue:0.29651
cur epoch:9, cur blue:0.32312, best epoch:9, best blue:0.32312
cur epoch:10, cur blue:0.30938, best epoch:9, best blue:0.32312
cur epoch:11, cur blue:0.33677, best epoch:11, best blue:0.33677
cur epoch:12, cur blue:0.34729, best epoch:12, best blue:0.34729
cur epoch:13, cur blue:0.34584, best epoch:12, best blue:0.34729
cur epoch:14, cur blue:0.32401, best epoch:12, best blue:0.34729
cur epoch:15, cur blue:0.29179, best epoch:12,

In [None]:
from rouge_score import rouge_scorer
model_t4 = Transformer(d_model = d_model, heads = heads, num_layers = num_layers, ntokens = ntokens).to(device)
model_s4 = Transformer(d_model = d_model, heads = heads, num_layers = num_layers, ntokens = ntokens).to(device)
model_s4.load_state_dict(torch.load('/content/drive/MyDrive/calibration_project/medical_dialogue_system/best_chatbot_models_sd_4t.pth'))
evaluate_matrics(model_s4,test_loader)

BLEU_SCORE1: 0.43973040628596327 Rouge: 0.4591312274684645 Meteor: 0.4262116411243918 PPL: 10.301880829149088


Training with SD, T=5

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class SoftTarget(nn.Module):
	'''
	Distilling the Knowledge in a Neural Network
	https://arxiv.org/pdf/1503.02531.pdf
	'''
	def __init__(self, T):
		super(SoftTarget, self).__init__()
		self.T = T

	def forward(self, out_s, out_t):
		loss = F.kl_div(F.log_softmax(out_s/self.T, dim=1),
						F.softmax(out_t/self.T, dim=1),
						reduction='batchmean') * self.T * self.T

		return loss

def train_sd(train_loader, transformer_t, transformer_s, criterion, criterionKD, transformer_optimizer, epoch):    
    transformer_s.train()
    transformer_t.eval()
    sum_loss = 0
    count = 0
    
    for i, pair in enumerate(train_loader): 
        question = pair.Question.to(device)
        reply = pair.Answer.to(device)
        reply_input = reply[:, :-1]
        reply_target = reply[:, 1:]

        # Create mask and add dimensions
        question_mask, reply_input_mask = create_masks(question, reply_input)
        out_s = transformer_s(question, question_mask, reply_input, reply_input_mask)
        with torch.no_grad():
            out_t = transformer_t(question, question_mask, reply_input, reply_input_mask)
        reply_target = reply_target.reshape(-1)
        loss_cls = criterion(out_s.view(-1, ntokens), reply_target)
        kd_loss = criterionKD(out_s.view(-1, ntokens), out_t.detach().view(-1, ntokens))
        loss = loss_cls + kd_loss
        # Backprop
        transformer_optimizer.optimizer.zero_grad()
        loss.backward()
        transformer_optimizer.step()

seed_everything()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
criterion = nn.CrossEntropyLoss()

d_model = 512
heads = 8
num_layers = 3
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_t5 = Transformer(d_model = d_model, heads = heads, num_layers = num_layers, ntokens = ntokens).to(device)
model_s5 = Transformer(d_model = d_model, heads = heads, num_layers = num_layers, ntokens = ntokens).to(device)
adam_optimizer = torch.optim.Adam(model_s5.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9)
transformer_optimizer = AdamWarmup(model_size = d_model, warmup_steps = 4000, optimizer = adam_optimizer)

model_t5.load_state_dict(torch.load('/content/drive/MyDrive/calibration_project/medical_dialogue_system/best_models_baseline.pth'))

T = 5
criterionKD = SoftTarget(T)
best_blue = 0
best_epoch = 0
for epoch in range(100):
    
    train_sd(train_loader, model_t5, model_s5, criterion, criterionKD, transformer_optimizer, epoch)
    blue_score = valid (test_loader, model_s5)
    if blue_score > best_blue:
        best_blue = blue_score
        best_epoch = epoch
        torch.save(model_s5.state_dict(), '/content/drive/MyDrive/calibration_project/medical_dialogue_system/best_chatbot_models_sd_5t.pth')
    print('cur epoch:%d, cur blue:%.5f, best epoch:%d, best blue:%.5f'%(epoch,blue_score, best_epoch, best_blue))

cur epoch:0, cur blue:0.00001, best epoch:0, best blue:0.00001
cur epoch:1, cur blue:0.05770, best epoch:1, best blue:0.05770
cur epoch:2, cur blue:0.14118, best epoch:2, best blue:0.14118
cur epoch:3, cur blue:0.14310, best epoch:3, best blue:0.14310
cur epoch:4, cur blue:0.17985, best epoch:4, best blue:0.17985
cur epoch:5, cur blue:0.22287, best epoch:5, best blue:0.22287
cur epoch:6, cur blue:0.26217, best epoch:6, best blue:0.26217
cur epoch:7, cur blue:0.27164, best epoch:7, best blue:0.27164
cur epoch:8, cur blue:0.27115, best epoch:7, best blue:0.27164
cur epoch:9, cur blue:0.31345, best epoch:9, best blue:0.31345
cur epoch:10, cur blue:0.33722, best epoch:10, best blue:0.33722
cur epoch:11, cur blue:0.32965, best epoch:10, best blue:0.33722
cur epoch:12, cur blue:0.33971, best epoch:12, best blue:0.33971
cur epoch:13, cur blue:0.33675, best epoch:12, best blue:0.33971
cur epoch:14, cur blue:0.34395, best epoch:14, best blue:0.34395
cur epoch:15, cur blue:0.31852, best epoch:14

In [None]:
from rouge_score import rouge_scorer
model_s5.load_state_dict(torch.load('/content/drive/MyDrive/calibration_project/medical_dialogue_system/best_chatbot_models_sd_5t.pth'))
evaluate_matrics(model_s5,test_loader)

BLEU_SCORE1: 0.44316121415895837 Rouge: 0.4596136337718651 Meteor: 0.4286404674318712 PPL: 9.125769864393439


In [None]:
seed_everything()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
criterion = nn.CrossEntropyLoss()
d_model = 512
heads = 8
num_layers = 3
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Transformer(d_model = d_model, heads = heads, num_layers = num_layers, ntokens = ntokens)
model = model.to(device)
adam_optimizer = torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9)
transformer_optimizer = AdamWarmup(model_size = d_model, warmup_steps = 4000, optimizer = adam_optimizer)

best_blue = 0
best_epoch = 0
for epoch in range(100):
    train(train_loader, model, criterion, epoch)
    blue_score = valid (test_loader,model)
    if blue_score > best_blue:
        best_blue = blue_score
        best_epoch = epoch
        
        torch.save(model.state_dict(), '/content/drive/MyDrive/calibration_project/medical_dialogue_system/best_models_baseline.pth')
    print('cur epoch:%d, cur blue:%.5f, best epoch:%d, best blue:%.5f'%(epoch,blue_score, best_epoch, best_blue))

The hypothesis contains 0 counts of 2-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()
The hypothesis contains 0 counts of 3-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()
The hypothesis contains 0 counts of 4-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()


cur epoch:0, cur blue:0.00567, best epoch:0, best blue:0.00567
cur epoch:1, cur blue:0.13392, best epoch:1, best blue:0.13392
cur epoch:2, cur blue:0.16344, best epoch:2, best blue:0.16344
cur epoch:3, cur blue:0.18116, best epoch:3, best blue:0.18116
cur epoch:4, cur blue:0.20916, best epoch:4, best blue:0.20916
cur epoch:5, cur blue:0.23907, best epoch:5, best blue:0.23907
cur epoch:6, cur blue:0.28745, best epoch:6, best blue:0.28745
cur epoch:7, cur blue:0.31096, best epoch:7, best blue:0.31096
cur epoch:8, cur blue:0.33110, best epoch:8, best blue:0.33110
cur epoch:9, cur blue:0.33065, best epoch:8, best blue:0.33110
cur epoch:10, cur blue:0.33944, best epoch:10, best blue:0.33944
cur epoch:11, cur blue:0.34200, best epoch:11, best blue:0.34200
cur epoch:12, cur blue:0.34763, best epoch:12, best blue:0.34763
cur epoch:13, cur blue:0.34833, best epoch:13, best blue:0.34833
cur epoch:14, cur blue:0.34357, best epoch:13, best blue:0.34833
cur epoch:15, cur blue:0.32893, best epoch:13

In [None]:
from rouge_score import rouge_scorer
evaluate_matrics(model,test_loader)

The hypothesis contains 0 counts of 3-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()
The hypothesis contains 0 counts of 4-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()
The hypothesis contains 0 counts of 2-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()


BLEU_SCORE1: 0.4256746608835476 Rouge: 0.45004191513969954 Meteor: 0.4090250663724253 PPL: 8.381151839452423
