In [1]:
import json
import torch
import torch.optim as optim
import torch.nn as nn
from torch.nn import functional as F
from transformers import AutoModel, AutoTokenizer
from torch.utils.data import DataLoader, Dataset
from torch.nn.utils.rnn import pad_sequence
from torch.utils.tensorboard import SummaryWriter
import numpy as np
from tqdm.auto import tqdm
import pandas as pd

# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

In [16]:
# Load both splits; lines=True reads the files as JSON Lines (one record per line).
train_data = pd.read_json("subtaskA_train_monolingual.jsonl", lines=True)
# The dev split is what the rest of the notebook evaluates on.
test_data = pd.read_json("subtaskA_dev_monolingual.jsonl", lines=True)

In [17]:
# Number of rows in each split.
len(train_data), len(test_data)

(119757, 5000)

In [18]:
# Preview the first rows to see which columns are available.
train_data.head()

Unnamed: 0,text,label,model,source,id
0,Forza Motorsport is a popular racing game that...,1,chatGPT,wikihow,0
1,Buying Virtual Console games for your Nintendo...,1,chatGPT,wikihow,1
2,Windows NT 4.0 was a popular operating system ...,1,chatGPT,wikihow,2
3,How to Make Perfume\n\nPerfume is a great way ...,1,chatGPT,wikihow,3
4,How to Convert Song Lyrics to a Song'\n\nConve...,1,chatGPT,wikihow,4


In [19]:
# Distinct values of the "model" and "label" columns in the training split.
np.unique(train_data["model"]), np.unique(train_data["label"])

(array(['chatGPT', 'cohere', 'davinci', 'dolly', 'human'], dtype=object),
 array([0, 1], dtype=int64))

Обрабатываем данные

In [13]:
class DetectionDateset(Dataset):
    """Map-style dataset yielding ``(text, label)`` pairs from a DataFrame.

    Args:
        data: DataFrame with a ``"text"`` column and a ``"label"`` column.
        n_samples: Maximum number of leading rows to keep. ``None`` or a
            negative value keeps every row.
    """

    def __init__(self, 
                 data,
                 n_samples=10000):
        super().__init__()

        # Slicing with a negative bound (``iloc[:-1]``) would silently drop the
        # trailing rows, so "no limit" is handled explicitly.
        if n_samples is None or n_samples < 0:
            subset = data
        else:
            subset = data.iloc[:n_samples]

        # Renumber rows 0..len-1 so the frame does not depend on the index the
        # caller's DataFrame happened to carry.
        self.data = subset.reset_index(drop=True)

    def __len__(self):
        """Return the number of retained rows."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the ``(text, label)`` pair stored at position ``idx``."""
        # Positional access: works regardless of index labels and raises
        # IndexError when ``idx`` is out of range.
        return self.data["text"].iloc[idx], self.data["label"].iloc[idx]

class Collator:
    """Batch collation callable: tokenizes the texts and tensorizes the labels.

    Args:
        tokenizer: Callable invoked with a list of texts; its result must
            expose ``"input_ids"`` and ``"attention_mask"`` entries.
        max_len: Value forwarded to the tokenizer as ``max_length``.
    """

    def __init__(self, tokenizer, max_len):
        self.max_len = max_len
        self.tokenizer = tokenizer

    def __call__(self, batch):
        """Convert a list of ``(text, label)`` samples into model inputs.

        Returns:
            Tuple ``(input_ids, attention_mask, labels)`` where ``labels`` is
            built with ``torch.tensor`` from the per-sample labels.
        """
        texts, labels = [], []
        for sample in batch:
            texts.append(sample[0])
            labels.append(sample[1])

        # padding=True pads to the longest sequence of this batch;
        # truncation=True cuts sequences at max_length.
        encoding = self.tokenizer(
            texts,
            add_special_tokens=True,
            truncation=True,
            max_length=self.max_len,
            padding=True,
            return_tensors="pt",
        )

        return encoding["input_ids"], encoding["attention_mask"], torch.tensor(labels)

Добавляем классификационную голову к tiny-BERT

In [14]:
class DetectionModel(nn.Module):
    """Pretrained transformer encoder followed by an MLP classification head.

    The head is applied to the hidden state at sequence position 0 of the
    backbone's ``last_hidden_state``.

    Args:
        pretrained_model_path: Name or path passed to
            ``AutoModel.from_pretrained``.
        hidden_dim: Width of the two hidden layers of the head.
        output_dim: Number of output logits.
        drouput: Dropout probability used inside the head. The misspelled
            name is kept because callers pass it by keyword.
    """

    def __init__(self, 
                 pretrained_model_path,
                 hidden_dim,
                 output_dim,
                 drouput):
        super().__init__()
        self.backbone = AutoModel.from_pretrained(pretrained_model_path)
        # Take the encoder width from the model config instead of reaching into
        # a fixed layer index, which fails for backbones with a single layer.
        encoder_dim = self.backbone.config.hidden_size
        self.clf_head = nn.Sequential(nn.Linear(encoder_dim, hidden_dim),
                                      nn.GELU(),
                                      nn.Dropout(p=drouput),
                                      nn.Linear(hidden_dim, hidden_dim),
                                      nn.GELU(),
                                      nn.Dropout(p=drouput),
                                      nn.Linear(hidden_dim, output_dim))
        
    def forward(self, input_ids, attention_mask=None):
        """Compute classification logits.

        Args:
            input_ids: Token ids, either ``(batch, seq)`` or a single
                unbatched ``(seq,)`` sequence.
            attention_mask: Optional mask with the same layout as
                ``input_ids``.

        Returns:
            Tensor of logits with shape ``(batch, output_dim)``.
        """
        # Promote a single unbatched sequence to a batch of one. The mask must
        # be promoted together with the ids so both stay aligned.
        if input_ids.dim() == 1:
            input_ids = input_ids.unsqueeze(0)
        if attention_mask is not None and attention_mask.dim() == 1:
            attention_mask = attention_mask.unsqueeze(0)

        embeds = self.backbone(input_ids=input_ids,
                               attention_mask=attention_mask)["last_hidden_state"]

        # Position 0 holds the first token of every sequence.
        cls_output = self.clf_head(embeds[:, 0, :])

        return cls_output

In [15]:
def train_epoch(train_loader, model, loss_function, optimizer, callback=None):
    """Run one pass over ``train_loader`` and return the mean per-sample loss.

    Args:
        train_loader: Iterable yielding ``(input_ids, attention_mask, labels)``.
        model: Model being trained.
        loss_function: Callable ``loss_function(preds, labels)``.
        optimizer: Instantiated optimizer that updates ``model``.
        callback: Optional callable invoked as ``callback(model, batch_loss)``
            after every batch, with gradients disabled.

    Returns:
        Average of the batch losses weighted by batch size, or ``nan`` when
        the loader yields no samples.
    """
    epoch_loss = 0.0
    total = 0
    for input_ids, attention_mask, labels in tqdm(train_loader, leave=False):

        batch_loss = train_on_batch(model, 
                                    input_ids, 
                                    attention_mask, 
                                    labels, 
                                    optimizer, 
                                    loss_function)

        if callback is not None:
            with torch.no_grad():
                callback(model, batch_loss)

        # Weight by batch size so a smaller final batch does not skew the mean.
        n_in_batch = len(labels)
        epoch_loss += batch_loss * n_in_batch
        total += n_in_batch

    # An empty loader would otherwise end in a ZeroDivisionError.
    if total == 0:
        return float("nan")

    return epoch_loss / total


def train_on_batch(model, 
                   input_ids, 
                   attention_mask, 
                   labels,
                   optimizer, 
                   loss_function):
    """Perform a single optimization step on one batch.

    Moves the batch to the notebook-level ``device``, puts the model in
    training mode, and runs forward, backward and optimizer update.

    Returns:
        The batch loss as a Python float.
    """
    input_ids = input_ids.to(device)
    attention_mask = attention_mask.to(device)
    labels = labels.to(device)

    model.train()
    optimizer.zero_grad()

    logits = model(input_ids, attention_mask)
    batch_loss = loss_function(logits, labels)
    batch_loss.backward()
    optimizer.step()

    return batch_loss.item()


def trainer(count_of_epoch, 
            batch_size, 
            loader,
            model, 
            loss_function,
            optimizer,
            lr = 0.001,
            callback = None):
    """Train ``model`` for ``count_of_epoch`` epochs, showing a progress bar.

    Args:
        count_of_epoch: Number of passes over ``loader``.
        batch_size: Not used by this function; kept in the signature for
            existing callers.
        loader: Training data loader handed to ``train_epoch``.
        model: Model to train.
        loss_function: Loss callable handed to ``train_epoch``.
        optimizer: Optimizer class or factory, called as
            ``optimizer(model.parameters(), lr=lr)``.
        lr: Learning rate for the optimizer.
        callback: Optional per-batch callback handed to ``train_epoch``.
    """
    optimizer_instance = optimizer(model.parameters(), lr=lr)

    epoch_bar = tqdm(range(count_of_epoch), desc='epoch')
    # Placeholder value until the first epoch has finished.
    epoch_bar.set_postfix({'train epoch loss': np.nan})

    for _ in epoch_bar:
        mean_loss = train_epoch(loader,
                                model,
                                loss_function,
                                optimizer_instance,
                                callback=callback)
        epoch_bar.set_postfix({'train epoch loss': mean_loss})

class Callback():
    """Training hook that logs to TensorBoard.

    Every call logs the batch loss under ``LOSS/train``. Every ``delimeter``
    calls it also runs the model over ``test_loader`` and logs the accuracy
    under ``Acc/test``.

    Args:
        writer: Object with an ``add_scalar(tag, value, step)`` method.
        test_loader: Iterable yielding ``(input_ids, attention_mask, labels)``.
        loss_function: Stored on the instance; not used for evaluation here.
        delimeter: Evaluate every this many calls.
        batch_size: Stored on the instance; not used for evaluation here.
    """

    def __init__(self, writer, test_loader, loss_function, delimeter=100, batch_size=64):
        self.step = 0
        self.writer = writer
        self.delimeter = delimeter
        self.loss_function = loss_function
        self.batch_size = batch_size

        self.loader = test_loader

    def forward(self, model, loss):
        """Log ``loss`` and, on every ``delimeter``-th call, the test accuracy."""
        self.step += 1
        self.writer.add_scalar('LOSS/train', loss, self.step)

        if self.step % self.delimeter != 0:
            return

        pred = []
        real = []
        # Remember the current mode so evaluation does not leave the model
        # stuck in eval mode for whoever called us.
        was_training = model.training
        model.eval()
        try:
            with torch.no_grad():
                for input_ids, attention_mask, labels in tqdm(self.loader, leave=False):

                    input_ids, attention_mask = input_ids.to(device), attention_mask.to(device)

                    output = model(input_ids, attention_mask)

                    pred.extend(torch.argmax(output, dim=-1).cpu().view(-1).tolist())
                    real.extend(labels.view(-1).tolist())
        finally:
            model.train(was_training)

        # With no evaluation samples the mean is undefined; skip the data point
        # instead of logging NaN.
        if real:
            test_acc = np.mean(np.array(pred) == np.array(real))
            self.writer.add_scalar('Acc/test', test_acc, self.step)

    def __call__(self, model, loss):
        return self.forward(model, loss)

In [10]:
# Load the TensorBoard notebook extension, then serve the logs found under the
# current directory on port 6002.
%load_ext tensorboard
%tensorboard --logdir ./ --port=6002

In [11]:
# Use every row of the training split. A limit of -1 would be applied as the
# slice bound iloc[:-1], which leaves out the last row.
train_dataset = DetectionDateset(train_data, n_samples=len(train_data))
test_dataset = DetectionDateset(test_data, n_samples=5000)

In [20]:
# --- Configuration ---------------------------------------------------------
pretrained_model_name = "prajjwal1/bert-tiny"  # shared by tokenizer and backbone
seed = 42
batch_size = 100
max_len = 512
n_epochs = 5
test_step_size = 50  # evaluate on the test loader every this many training steps
lr = 3e-4
hidden_dim = 256
output_dim = 2
dropout = 0.1
optimizer = torch.optim.Adam

# Seed PyTorch so head initialisation, dropout and shuffling are repeatable
# between runs.
torch.manual_seed(seed)

# --- Data ------------------------------------------------------------------
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name)
collator = Collator(tokenizer, max_len=max_len)

train_loader = DataLoader(train_dataset, 
                         shuffle=True, 
                         batch_size=batch_size,
                         collate_fn=collator)
test_loader = DataLoader(test_dataset, 
                         shuffle=False, 
                         batch_size=batch_size,
                         collate_fn=collator)

# --- Model, loss, logging --------------------------------------------------
loss_function = nn.CrossEntropyLoss()

model = DetectionModel(pretrained_model_path=pretrained_model_name,
                       hidden_dim=hidden_dim,
                       output_dim=output_dim,
                       drouput=dropout).to(device)

writer = SummaryWriter(log_dir="./run0")

callback = Callback(writer, 
                    test_loader, 
                    loss_function, 
                    delimeter=test_step_size)

# --- Training --------------------------------------------------------------
try:
    trainer(count_of_epoch=n_epochs, 
            batch_size=batch_size, 
            loader=train_loader,
            model=model, 
            loss_function=loss_function,
            optimizer=optimizer,
            lr=lr,
            callback=callback)
finally:
    # Flush pending events to disk even if training is interrupted.
    writer.close()

epoch:   0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/1198 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/1198 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/1198 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/1198 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/1198 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

Комментировать особенно нечего: взяли tiny-BERT, взяли выход на позиции CLS-токена и подали его в голову для классификации. Модель завелась и как-то обучилась, выдаёт accuracy > 0.5 — и ладно. Для улучшения качества можно взять модель побольше и немного подобрать гиперпараметры, но задача, кажется, не в выбивании скора, поэтому этим ограничимся.