In [1]:
import os
import numpy as np
import torch
import json
import math
import torch.nn as nn
from tqdm import tqdm
from sklearn.metrics import f1_score, accuracy_score
from sklearn.cluster import DBSCAN
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import BertModel, BertTokenizer
from torch.utils.tensorboard import SummaryWriter

In [2]:
# Run on the first GPU when one is visible, otherwise fall back to the CPU.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
device

device(type='cuda', index=0)

In [3]:
class dataset(Dataset):
    """Map-style dataset over one PAN23 multi-author-analysis split.

    Item ``k`` is ``(paragraphs, truth)`` where ``paragraphs`` is the list of
    lines of ``problem-<id>.txt`` (trailing newlines kept) and ``truth`` is the
    parsed content of ``truth-problem-<id>.json``.
    """

    def __init__(self, data_root, setlen, skip_indices=None):
        """
        Args:
            data_root: directory holding the ``problem-*.txt`` and
                ``truth-problem-*.json`` files (trailing separator optional).
            setlen: number of problems; files are numbered ``1..setlen``.
            skip_indices: zero-based problem indices to leave out. ``None``
                keeps the notebook's original rule: skip indices 70 and 259
                (problems 71 and 260) when ``setlen == 900``, nothing otherwise.
                NOTE(review): the reason for skipping those two problems is
                not stated here — confirm against the data.
        """
        if skip_indices is None:
            skip_indices = {70, 259} if setlen == 900 else set()
        self.text_path = []
        self.label_path = []
        for i in range(setlen):
            if i in skip_indices:
                continue
            problem_id = i + 1  # files are numbered from 1
            self.text_path.append(os.path.join(data_root, f'problem-{problem_id}.txt'))
            self.label_path.append(os.path.join(data_root, f'truth-problem-{problem_id}.json'))

    def __len__(self) -> int:
        return len(self.text_path)

    def __getitem__(self, item):
        # `with` guarantees the text file is closed (the previous bare open() leaked it).
        with open(self.text_path[item], encoding='utf-8') as text_file:
            paragraphs = text_file.readlines()

        with open(self.label_path[item], encoding='utf-8') as json_file:
            truth = json.load(json_file)

        return (paragraphs, truth)
                



In [4]:
class StyleSpy(nn.Module):
    """Paragraph style encoder: a frozen ``bert-base-uncased`` plus two trainable heads.

    ``forward(text)`` tokenizes one paragraph and returns
    ``torch.cat((ffn_cls(pooler_output), ffn_hidden(mean_token_state)), dim=1)``,
    i.e. a tensor of shape ``(1, 2 * n_features)``.
    """

    # Maximum number of word-piece tokens (including [CLS]/[SEP]) fed to BERT.
    MAX_LENGTH = 256

    def __init__(self, n_features=512, hidden_size=1024, padding='max_length', dropout=0.1):
        super(StyleSpy, self).__init__()
        self.padding = padding
        self.berttokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        self.bert = BertModel.from_pretrained('bert-base-uncased')
        bert_dim = self.bert.config.hidden_size
        # Heads are built in the same order and with the same layer layout as
        # before, so parameter initialisation and state_dict keys are unchanged.
        self.ffn_hidden = self._make_head(bert_dim, hidden_size, n_features, dropout)
        self.ffn_cls = self._make_head(bert_dim, hidden_size, n_features, dropout)

        # Freeze the BERT part: only the two heads receive gradients.
        # NOTE(review): BERT's own dropout layers still follow model.train()/eval().
        for param in self.bert.parameters():
            param.requires_grad = False

    @staticmethod
    def _make_head(in_dim, hidden_size, n_features, dropout):
        """Three Linear/ReLU/Dropout stages followed by a LayerNorm."""
        return nn.Sequential(nn.Linear(in_dim, hidden_size),
                             nn.ReLU(),
                             nn.Dropout(dropout),
                             nn.Linear(hidden_size, n_features),
                             nn.ReLU(),
                             nn.Dropout(dropout),
                             nn.Linear(n_features, n_features),
                             nn.ReLU(),
                             nn.Dropout(dropout),
                             nn.LayerNorm(normalized_shape=n_features))

    @staticmethod
    def _masked_mean(hidden, mask):
        """Mean of ``hidden`` (batch, seq, dim) over positions where ``mask`` (batch, seq) is non-zero.

        Rows with an all-zero mask yield zeros instead of dividing by zero.
        """
        weights = mask.unsqueeze(-1).to(hidden.dtype)
        summed = (hidden * weights).sum(dim=1)
        counts = weights.sum(dim=1).clamp(min=1.0)
        return summed / counts

    def tokenize(self, text):
        """Encode ``text`` into ``(input_ids, attention_mask)``, each shaped (1, seq_len).

        Both tensors are placed on the device of the model's own parameters.
        """
        input_ids = self.berttokenizer.encode(text, add_special_tokens=True, padding=self.padding,
                                              truncation=True, max_length=self.MAX_LENGTH)
        pad_id = self.berttokenizer.pad_token_id
        attention_mask = [int(token_id != pad_id) for token_id in input_ids]

        # Follow the model rather than the notebook-global `device`: the model
        # may not have been moved there, which would mix CPU and CUDA tensors.
        target = next(self.parameters()).device
        input_ids = torch.tensor(input_ids).unsqueeze(0).to(target)
        attention_mask = torch.tensor(attention_mask).unsqueeze(0).to(target)

        return input_ids, attention_mask

    def forward(self, text):
        input_ids, attention_mask = self.tokenize(text)
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)

        # Mean-pool only real word-piece tokens. A fixed slice [:, 1:-1] is not
        # enough: with max_length padding it strips a [PAD] instead of [SEP]
        # and the average is dominated by padding positions.
        special = ((input_ids == self.berttokenizer.cls_token_id)
                   | (input_ids == self.berttokenizer.sep_token_id))
        content_mask = attention_mask.masked_fill(special, 0)
        hidden_state = self._masked_mean(outputs.last_hidden_state, content_mask)
        cls_token = outputs.pooler_output

        hidden_state = self.ffn_hidden(hidden_state)
        cls_token = self.ffn_cls(cls_token)
        features = torch.cat((cls_token, hidden_state), dim=1)

        return features
        

In [5]:
def adjust_learning_rate(epochs, batch_size, loader, step, learning_rate=1.0, warmup_epochs=10):
    """Linear-warmup + cosine-decay learning-rate schedule.

    The schedule scale starts at ``batch_size / 256``. It ramps linearly from 0
    over ``warmup_epochs * len(loader)`` steps. It then decays along a cosine
    to 0.1% of that base value at the last step and stays there afterwards.

    Args:
        epochs: total number of training epochs.
        batch_size: batch size used to scale the base value.
        loader: anything with ``len()``; only its length (steps per epoch) is used.
        step: global step index (0-based).
        learning_rate: multiplier applied to the schedule. The previous
            version returned ``lr * lr`` because the local ``lr`` shadowed the
            global one, which squared the schedule.
        warmup_epochs: number of warm-up epochs.

    Returns:
        The learning rate for ``step``.
    """
    steps_per_epoch = len(loader)
    warmup_steps = warmup_epochs * steps_per_epoch
    # At least one decay step, so epochs <= warmup_epochs does not divide by zero.
    decay_steps = max(epochs * steps_per_epoch - warmup_steps, 1)
    base_lr = batch_size / 256
    if step < warmup_steps:
        scale = base_lr * step / warmup_steps
    else:
        progress = min(step - warmup_steps, decay_steps) / decay_steps
        q = 0.5 * (1 + math.cos(math.pi * progress))
        end_lr = base_lr * 0.001
        scale = base_lr * q + end_lr * (1 - q)
    return scale * learning_rate

In [6]:
def loss_fn(truth, para_embeddings, threshold=4):
    """Contrastive loss over consecutive paragraph embeddings.

    For every adjacent pair ``(i, i + 1)`` with Euclidean distance ``D`` and
    change label ``w`` (0 = same author, 1 = author change)::

        l = (1 - w) * D + w * max(threshold - D, 0)

    Same-author pairs are pulled together and changed pairs are pushed at
    least ``threshold`` apart. The result is the mean over all pairs.

    Args:
        truth: label dict; ``truth['changes']`` has one entry per adjacent
            paragraph pair (tensors after DataLoader collation, or ints).
        para_embeddings: list of embedding tensors, one per paragraph.
            Presumably each is (1, n_features) as produced by StyleSpy — not
            enforced here.
        threshold: margin for changed pairs.

    Returns:
        Loss tensor. With (1, n) embeddings it has shape (1, 1). For a document
        with no pairs it is a zero that stays attached to the autograd graph,
        so ``loss.backward()`` still works.
    """
    weights = truth['changes']
    if len(weights) == 0:
        if not para_embeddings:
            return torch.zeros(1, 1)
        return para_embeddings[0].sum().reshape(1, 1) * 0.0

    d_loss = 0.0
    for i, weight in enumerate(weights):
        anchor = para_embeddings[i]
        # Put the label on the embeddings' device instead of a global `device`.
        weight = torch.as_tensor(weight, device=anchor.device)
        distance = torch.cdist(anchor, para_embeddings[i + 1])
        pair_loss = (1 - weight) * distance + weight * torch.clamp(threshold - distance, min=0.0)
        d_loss = d_loss + pair_loss
    return d_loss / len(weights)

In [7]:
def evaluate(model, loader, threshold=4):
    """Evaluate paragraph-level style-change detection on ``loader``.

    Each paragraph is embedded with ``model``. A change between consecutive
    paragraphs is predicted when the Euclidean distance between their
    embeddings exceeds ``threshold``.

    The model is put in eval mode (dropout off) and gradients are disabled
    for the duration. The model's previous train/eval mode is restored
    afterwards.

    Args:
        model: ``nn.Module`` mapping one paragraph string to an embedding.
        loader: iterable of ``(paragraphs, truth)`` from a ``DataLoader`` with
            ``batch_size=1`` (each paragraph arrives as a 1-tuple).
        threshold: distance above which a pair is labelled as a change.

    Returns:
        ``(F1, acc_para)``: binary F1 of the change predictions and the
        fraction of paragraph pairs classified correctly. Returns
        ``(0.0, 0.0)`` when there are no pairs to score.
    """
    predictions = []
    true_label = []

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for paragraphs, truth in tqdm(loader, total=len(loader), leave=False):
                para_embeddings = [model(para[0]) for para in paragraphs]
                for i, weight in enumerate(truth['changes']):
                    distance = torch.cdist(para_embeddings[i], para_embeddings[i + 1])
                    predictions.append(int(distance.item() > threshold))
                    true_label.append(weight.item())
    finally:
        model.train(was_training)

    if not predictions:
        print("No paragraph pairs to evaluate.")
        return 0.0, 0.0

    # scikit-learn metrics take (y_true, y_pred).
    F1 = f1_score(true_label, predictions)
    acc_para = accuracy_score(true_label, predictions)

    print(f"F1 score: {F1:.4f}, acc_para: {acc_para:.4f}")
    return F1, acc_para



In [8]:
def train(model, trainloader, valloader, epochs, optimizer, threshold, save_freq, loss_fn=loss_fn,
          fixed_lr=1e-6, eval_freq=4000, log_dir="./log/",
          checkpoint_dir=os.path.join('.', 'checkpoints')):
    """Train ``model`` with ``loss_fn``, periodically validating and checkpointing.

    Args:
        model: paragraph encoder called as ``model(text)``.
        trainloader / valloader: DataLoaders (``batch_size=1``) over ``dataset``.
        epochs: number of passes over ``trainloader``.
        optimizer: optimizer already built over ``model``'s trainable parameters.
        threshold: margin passed to ``loss_fn`` and decision threshold for ``evaluate``.
        save_freq: every ``save_freq`` global steps the loss is appended to
            ``<log_dir>/logs.txt``.
        loss_fn: ``loss_fn(truth, para_embeddings, threshold=...)`` -> loss tensor.
        fixed_lr: learning rate written into every optimizer param group before
            training (overriding the optimizer's configured value). ``None``
            keeps the optimizer's own learning rate.
        eval_freq: every ``eval_freq`` global steps the model is validated. The
            full model is saved to ``checkpoint_dir`` when accuracy improves.
        log_dir: TensorBoard / text log directory.
        checkpoint_dir: directory for best and final checkpoints (created if missing).
    """
    os.makedirs(log_dir, exist_ok=True)
    os.makedirs(checkpoint_dir, exist_ok=True)
    writer = SummaryWriter(log_dir)

    if fixed_lr is not None:
        for group in optimizer.param_groups:
            group['lr'] = fixed_lr

    best_acc = 0.0
    loss = None
    try:
        for epoch in range(epochs):
            model.train()
            # `step` is a global step count across epochs.
            loop = tqdm(enumerate(trainloader, start=epoch * len(trainloader)),
                        total=len(trainloader), leave=False)
            loop.set_description(f'Epoch [{epoch}/{epochs}]')
            for step, (paragraphs, truth) in loop:
                optimizer.zero_grad()
                para_embeddings = [model(para[0]) for para in paragraphs]
                loss = loss_fn(truth, para_embeddings, threshold=threshold)
                loss.backward()
                optimizer.step()

                loss_value = loss.item()
                # Log against the global step so points within an epoch do not collapse.
                writer.add_scalar("Loss/train", loss_value, step)

                if step and step % int(save_freq) == 0:
                    with open(os.path.join(log_dir, 'logs.txt'), 'a') as log_file:
                        log_file.write(f'Epoch: {epoch}, Step: {step}, Train loss: {loss_value} \n')

                if step and step % int(eval_freq) == 0:
                    with torch.no_grad():
                        F1, acc_para = evaluate(model, valloader, threshold=threshold)
                    if acc_para > best_acc:
                        best_acc = acc_para
                        torch.save(model, os.path.join(
                            checkpoint_dir, f'best_{best_acc}_acc_{F1}_F1_{threshold}_thre.pth'))
                loop.set_postfix(loss=loss_value)

            if loss is not None:
                print(f'Loss for epoch {epoch} is {loss.item()}')

        print('End of the Training. Saving final checkpoints.')
        state = dict(epoch=epochs, model=model.state_dict(),
                     optimizer=optimizer.state_dict())
        torch.save(state, os.path.join(checkpoint_dir, 'final_checkpoint.pth'))
    finally:
        writer.flush()
        writer.close()
                
                

In [9]:
"""
dataset1: training 4200, validation 900
dataset2: training 4200, validation 900
dataset3: training 4200, validation 900
label format: {'author': int -> number of authors occur in this file 
                'changes': int list -> length equals to num_paragraghs-1
                                        every time a new paragragh appears-> 0=unchanged, 1=changed
                }
"""
training_path1 = "./release/pan23-multi-author-analysis-dataset1/pan23-multi-author-analysis-dataset1-train/"
val_path1 = "./release/pan23-multi-author-analysis-dataset1/pan23-multi-author-analysis-dataset1-validation/"
training_path2 = "./release/pan23-multi-author-analysis-dataset2/pan23-multi-author-analysis-dataset2-train/"
val_path2 = "./release/pan23-multi-author-analysis-dataset2/pan23-multi-author-analysis-dataset2-validation/"
training_path3 = "./release/pan23-multi-author-analysis-dataset3/pan23-multi-author-analysis-dataset3-train/"
val_path3 = "./release/pan23-multi-author-analysis-dataset3/pan23-multi-author-analysis-dataset3-validation/"

In [16]:
# Training configuration:
#   epochs     - number of passes over the training loader
#   lr         - learning-rate setting (NOTE(review): the optimizer / train() may override it)
#   batch_size - documents per batch; the loaders below are built with batch_size=1
#   threshold  - embedding-distance margin (loss) and change-decision threshold (evaluation)
epochs, lr, batch_size, threshold = 100, 0.1, 1, 4

In [10]:
# model = torch.load('./checkpoints/best_0.7297875374304862_acc_0.6318243637070139_F1_4_thre.pth')
# model = model.to(device)

In [11]:
# Build the model *before* its optimizer so that the optimizer tracks the
# parameters of the model that is actually trained. Only parameters with
# requires_grad (the two heads; BERT is frozen inside StyleSpy) are optimised.
model = StyleSpy().to(device)
optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()),
                             lr=lr,
                             betas=(0.9, 0.999),
                             eps=1e-08,
                             weight_decay=0,
                             amsgrad=False)

In [12]:
# Dataset 1 loaders. batch_size stays 1: the training/evaluation loops take
# para[0] for each paragraph and call .item() on each change label.
Training_set1 = dataset(data_root=training_path1, setlen=4200)
trainingloader1 = DataLoader(dataset=Training_set1, batch_size=1, shuffle=True)
Val_set1 = dataset(data_root=val_path1, setlen=900)
# Validation metrics do not depend on order, so keep iteration deterministic.
valloader1 = DataLoader(dataset=Val_set1, batch_size=1, shuffle=False)

In [13]:
# Dataset 2 loaders. batch_size stays 1: the training/evaluation loops take
# para[0] for each paragraph and call .item() on each change label.
Training_set2 = dataset(data_root=training_path2, setlen=4200)
trainingloader2 = DataLoader(dataset=Training_set2, batch_size=1, shuffle=True)
Val_set2 = dataset(data_root=val_path2, setlen=900)
# Validation metrics do not depend on order, so keep iteration deterministic.
valloader2 = DataLoader(dataset=Val_set2, batch_size=1, shuffle=False)

In [14]:
# Dataset 3 loaders. batch_size stays 1: the training/evaluation loops take
# para[0] for each paragraph and call .item() on each change label.
Training_set3 = dataset(data_root=training_path3, setlen=4200)
trainingloader3 = DataLoader(dataset=Training_set3, batch_size=1, shuffle=True)
Val_set3 = dataset(data_root=val_path3, setlen=900)
# Validation metrics do not depend on order, so keep iteration deterministic.
valloader3 = DataLoader(dataset=Val_set3, batch_size=1, shuffle=False)

In [None]:
# Do not re-create the model here: the optimizer built earlier holds references
# to the existing model's parameters, so a fresh StyleSpy() would never be updated.
# Just make sure the model lives on `device`.
model = model.to(device)
_tracked_params = {id(p) for group in optimizer.param_groups for p in group['params']}
assert all(id(p) in _tracked_params for p in model.parameters() if p.requires_grad), \
    "optimizer does not track this model's parameters; rebuild the optimizer for this model"

In [None]:
# Train on dataset 2 and validate on its validation split; the loss is logged every 1000 steps.
train(model, trainingloader2, valloader2, epochs, optimizer, threshold, save_freq=1000)