In [2]:
import pandas as pd
import torch
import pytorch_lightning as pl
from torch.utils.data import Dataset, DataLoader
from transformers import AutoModel, AutoTokenizer
import torch.nn as nn
import torch.nn.functional as F
from sklearn.model_selection import train_test_split
from torchmetrics import Accuracy, F1Score

df = pd.read_csv('/kaggle/input/qefasfas/train.tsv')
df['class'] = df['class'].astype(int) 


train_df, val_df = train_test_split(df, test_size=0.2, random_state=43)


class TweetDataset(Dataset):
    def __init__(self, texts, targets):
        self.texts = texts
        self.targets = targets

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        return {'text': self.texts[idx], 'class': self.targets[idx]}


class Collator:
    def __init__(self, tokenizer, max_length=512):
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __call__(self, batch):
        texts = [item['text'] for item in batch]
        targets = [item['class'] for item in batch]
        encoding = self.tokenizer(
            texts,
            padding=True,
            truncation=True,
            max_length=self.max_length,
            return_tensors='pt'
        )
        return {
            'input_ids': encoding['input_ids'],
            'attention_mask': encoding['attention_mask'],
            'labels': torch.tensor(targets, dtype=torch.long)
        }

class ClassificationModel(pl.LightningModule):
    def __init__(self, model_name="deepvk/USER-base", lr=2e-5):
        super().__init__()
        self.save_hyperparameters()
        
        self.model = AutoModel.from_pretrained(model_name)
        self.classifier = nn.Linear(self.model.config.hidden_size, 1)
        
        self.register_buffer('pos_weight', torch.tensor([5.0]))
        self.loss_fn = nn.BCEWithLogitsLoss(pos_weight=self.pos_weight)

        self.train_acc = Accuracy(task="binary")
        self.val_acc = Accuracy(task="binary")
        self.train_f1 = F1Score(task="binary")
        self.val_f1 = F1Score(task="binary")

    def forward(self, input_ids, attention_mask):
        outputs = self.model(input_ids, attention_mask=attention_mask)
        pooled = outputs.last_hidden_state.mean(dim=1)
        logits = self.classifier(pooled)
        return logits.squeeze()

    def training_step(self, batch, batch_idx):
        logits = self(batch['input_ids'], batch['attention_mask'])
        loss = self.loss_fn(logits, batch['labels'].float())
        preds = torch.sigmoid(logits) > 0.5
        
        self.train_acc(preds, batch['labels'])
        self.train_f1(preds, batch['labels'])
        self.log('train_loss', loss, prog_bar=True)
        return loss

    def on_train_epoch_end(self):
        self.log('train_acc', self.train_acc.compute(), prog_bar=True)
        self.log('train_f1', self.train_f1.compute(), prog_bar=True)
        self.train_acc.reset()
        self.train_f1.reset()

    def validation_step(self, batch, batch_idx):
        logits = self(batch['input_ids'], batch['attention_mask'])
        loss = self.loss_fn(logits, batch['labels'].float())
        preds = torch.sigmoid(logits) > 0.5
        
        self.val_acc(preds, batch['labels'])
        self.val_f1(preds, batch['labels'])
        self.log('val_loss', loss, prog_bar=True)
        return loss

    def on_validation_epoch_end(self):
        self.log('val_acc', self.val_acc.compute(), prog_bar=True)
        self.log('val_f1', self.val_f1.compute(), prog_bar=True)
        self.val_acc.reset()
        self.val_f1.reset()

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=self.hparams.lr)


model_name = "deepvk/USER-base"
tokenizer = AutoTokenizer.from_pretrained(model_name)
collator = Collator(tokenizer)

train_loader = DataLoader(
    TweetDataset(train_df['tweet'].tolist(), train_df['class'].tolist()),
    batch_size=32,
    shuffle=True,
    num_workers=4,
    collate_fn=collator,
    persistent_workers=True
)

val_loader = DataLoader(
    TweetDataset(val_df['tweet'].tolist(), val_df['class'].tolist()),
    batch_size=32,
    num_workers=4,
    collate_fn=collator,
    persistent_workers=True
)


checkpoint = pl.callbacks.ModelCheckpoint(
    monitor='val_f1',
    mode='max',
    save_top_k=1,
    filename='best-{epoch}-{val_f1:.2f}'
)


trainer = pl.Trainer(
    max_epochs=7,
    accelerator='gpu',
    devices=-1,  
    strategy='ddp_notebook',
    callbacks=[checkpoint],
    enable_progress_bar=True,
    log_every_n_steps=10
)

model = ClassificationModel()
trainer.fit(model, train_loader, val_loader)

best_model_path = checkpoint.best_model_path
print(f"Best model saved at: {best_model_path}")

/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/connectors/logger_connector/result.py:434: It is recommended to use `self.log('val_loss', ..., sync_dist=True)` when logging on epoch level in distributed setting to accumulate the metric across devices.
/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/connectors/logger_connector/result.py:434: It is recommended to use `self.log('val_acc', ..., sync_dist=True)` when logging on epoch level in distributed setting to accumulate the metric across devices.
/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/connectors/logger_connector/result.py:434: It is recommended to use `self.log('val_f1', ..., sync_dist=True)` when logging on epoch level in distributed setting to accumulate the metric across devices.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/connectors/logger_connector/result.py:434: It is recommended to use `self.log('train_acc', ..., sync_dist=True)` when logging on epoch level in distributed setting to accumulate the metric across devices.
/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/connectors/logger_connector/result.py:434: It is recommended to use `self.log('train_f1', ..., sync_dist=True)` when logging on epoch level in distributed setting to accumulate the metric across devices.


Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Best model saved at: /kaggle/working/lightning_logs/version_26/checkpoints/best-epoch=3-val_f1=0.53.ckpt


In [13]:
import numpy as np
def predict_with_existing_model(model, data_loader, device='cuda', has_labels=True):
    model.to(device)
    model.eval()
    
    probabilities = []
    all_targets = []
    
    with torch.no_grad():
        for batch in data_loader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            
            logits = model(input_ids, attention_mask)
            probas = torch.sigmoid(logits)
            
            probabilities.extend(probas.cpu().numpy())
            if has_labels:
                all_targets.extend(batch['labels'].cpu().numpy())
    
    # Подбор трешхолда если есть метки
    if has_labels and len(all_targets) > 0:
        best_thresh, best_f1 = find_optimal_threshold(all_targets, np.array(probabilities))
        print(f"Optimal threshold: {best_thresh:.4f} with F1: {best_f1:.4f}")
    else:
        best_thresh = 0.015
        print("Using default threshold 0.5")
    
    predictions = (np.array(probabilities) >= best_thresh).astype(int)
    return predictions, probabilities

test_df = pd.read_csv("/kaggle/input/qefasfas/test.tsv")
test_dataset = TweetDataset(test_df['tweet'].tolist(), [0]*len(test_df))  
test_loader = DataLoader(test_dataset, batch_size=32, collate_fn=collator)

test_predictions, test_probas = predict_with_existing_model(
    model=model,
    data_loader=test_loader,
    device='cuda',
    has_labels=False
)


test_df['class'] = test_predictions
test_df['probability'] = test_probas
test_df[['id','class']].to_csv('predictions.csv', index=False)

Using default threshold 0.5
