In [None]:
# !pip install numpy
# !pip install pandas
# !pip install scikit-learn
# !pip install torch
# !pip install transformers


from transformers import RobertaModel, RobertaTokenizer
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score, classification_report, hamming_loss, roc_auc_score, average_precision_score
from collections import defaultdict
from torch.amp import autocast, GradScaler
import torch.nn.functional as F
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import time

## Hyperparameters
# Fixed token length handed to every CustomDataset as `max_len`; the tokenizer
# pads or truncates each document to exactly this many tokens.
MAX_LEN = 512
# Batch sizes for the train / validation / test DataLoaders respectively.
TRAIN_BATCH_SIZE = 32
VALID_BATCH_SIZE = 32
TEST_BATCH_SIZE = 32
# Number of train-then-validate passes performed by the training loop.
EPOCHS = 12
# Learning rate given to the AdamW optimizer.
LEARNING_RATE = 1e-05
# Probability cut-off meant for turning sigmoid outputs into 0/1 label predictions.
THRESHOLD = 0.5 # threshold for the sigmoid

## Dataset Class
class CustomDataset(torch.utils.data.Dataset):
    """Map-style dataset that pairs tokenized documents with multi-label targets.

    The text of each sample is read from the 'File Contents' column of ``df``
    and the label vector from the ``target_list`` columns. Every text is
    encoded by ``tokenizer`` to exactly ``max_len`` tokens (padded or
    truncated).
    """

    # Keys of the tokenizer output that are copied into each sample, in order.
    _ENCODED_KEYS = ('input_ids', 'attention_mask', 'token_type_ids')

    def __init__(self, df, tokenizer, max_len, target_list):
        self.df = df
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.title = list(df['File Contents'])
        self.targets = self.df[target_list].values

    def __len__(self):
        """Return the number of documents in the dataset."""
        return len(self.title)

    def __getitem__(self, index):
        """Return the encoded document, its label vector and its cleaned text.

        The result is a dict with the 1-D tensors 'input_ids',
        'attention_mask' and 'token_type_ids', the float tensor 'targets',
        and the whitespace-normalised text under 'title'.
        """
        # Collapse every run of whitespace (including newlines) to one space.
        text = " ".join(str(self.title[index]).split())

        encoding = self.tokenizer.encode_plus(
            text,
            None,
            add_special_tokens=True,
            max_length=self.max_len,
            padding='max_length',
            return_token_type_ids=True,
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt'
        )

        # The tokenizer returns (1, max_len) tensors; drop the batch axis so
        # the DataLoader can stack samples itself.
        sample = {key: encoding[key].flatten() for key in self._ENCODED_KEYS}
        sample['targets'] = torch.FloatTensor(self.targets[index])
        sample['title'] = text
        return sample

## Data
# Read the labelled CSV splits from the working directory.
train_file = 'train.csv'
test_file = 'test.csv'
train_val_df = pd.read_csv(train_file)
test_df = pd.read_csv(test_file)

# Hold out 20% of the training file for validation (fixed seed for repeatability).
train_df, val_df = train_test_split(train_val_df, test_size=0.2, random_state=42)

# Every column after the first one is treated as a label column.
target_list = list(train_df.columns[1:])

## Tokenizer
tokenizer = RobertaTokenizer.from_pretrained('roberta-base')

## Datasets (one per split, sharing the tokenizer and label columns)
train_dataset = CustomDataset(df=train_df, tokenizer=tokenizer, max_len=MAX_LEN, target_list=target_list)
valid_dataset = CustomDataset(df=val_df, tokenizer=tokenizer, max_len=MAX_LEN, target_list=target_list)
test_dataset = CustomDataset(df=test_df, tokenizer=tokenizer, max_len=MAX_LEN, target_list=target_list)

## Data Loader
# Only the training split is shuffled; validation and test keep file order.
train_data_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=TRAIN_BATCH_SIZE, shuffle=True, num_workers=0
)
val_data_loader = torch.utils.data.DataLoader(
    valid_dataset, batch_size=VALID_BATCH_SIZE, shuffle=False, num_workers=0
)
test_data_loader = torch.utils.data.DataLoader(
    test_dataset, batch_size=TEST_BATCH_SIZE, shuffle=False, num_workers=0
)

# Use the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

## Model

class RobertaInceptionAttentionImproved(nn.Module):
    """RoBERTa encoder + Inception-style Conv1d block + self-attention classifier.

    Pipeline of ``forward``:
      1. RoBERTa token embeddings (last hidden state) with dropout.
      2. Four parallel Conv1d branches (kernel sizes 2, 3, 5, 7; 128 channels
         each) whose outputs are zero-padded back to the input length.
      3. Concatenation of the token embeddings with the four conv outputs.
      4. Multi-head self-attention that ignores padded keys.
      5. Mean pooling over the real (non-padding) tokens only.
      6. Dense + LayerNorm, dropout and a linear layer producing one logit
         per class (no sigmoid is applied here).

    Args:
        num_classes: number of output logits.
    """

    def __init__(self, num_classes):
        super().__init__()

        # Pretrained encoder. NOTE: output_hidden_states=True is kept from the
        # original configuration; forward() only reads last_hidden_state.
        self.roberta = RobertaModel.from_pretrained('roberta-base', output_hidden_states=True)
        hidden_size = self.roberta.config.hidden_size  # 768 for roberta-base
        conv_channels = 128  # output channels of each convolution branch

        # Dropout layer after RoBERTa output
        self.dropout = nn.Dropout(0.3)

        # Inception block: parallel convolutions with different kernel sizes
        self.conv2 = nn.Conv1d(in_channels=hidden_size, out_channels=conv_channels, kernel_size=2, padding=0)
        self.conv3 = nn.Conv1d(in_channels=hidden_size, out_channels=conv_channels, kernel_size=3, padding=0)
        self.conv5 = nn.Conv1d(in_channels=hidden_size, out_channels=conv_channels, kernel_size=5, padding=0)
        self.conv7 = nn.Conv1d(in_channels=hidden_size, out_channels=conv_channels, kernel_size=7, padding=0)

        # Width of the per-token feature vector after concatenating the
        # embeddings with the four conv branches (768 + 4 * 128 = 1280).
        feature_dim = hidden_size + 4 * conv_channels

        # Self-attention layer after Inception block
        self.attention = nn.MultiheadAttention(embed_dim=feature_dim, num_heads=8, batch_first=True)

        # Additional dense layer with LayerNorm for refined feature interaction
        self.dense = nn.Sequential(
            nn.Linear(feature_dim, 512),
            nn.ReLU(),
            nn.LayerNorm(512)
        )

        # Final dropout and classification layer
        self.final_dropout = nn.Dropout(0.3)
        self.fc = nn.Linear(512, num_classes)

    @staticmethod
    def _masked_mean_pool(features, attention_mask):
        """Average ``features`` over the sequence axis, skipping padded tokens.

        Args:
            features: tensor of shape (batch, seq, dim).
            attention_mask: tensor of shape (batch, seq); non-zero marks a
                real token, zero marks padding.

        Returns:
            Tensor of shape (batch, dim). A row whose mask is all zero yields
            zeros instead of NaN.
        """
        keep = attention_mask.bool().unsqueeze(-1)  # (batch, seq, 1)
        summed = features.masked_fill(~keep, 0.0).sum(dim=1)
        # clamp avoids a division by zero for fully padded rows
        counts = keep.sum(dim=1).clamp(min=1).to(features.dtype)
        return summed / counts

    def forward(self, input_ids, attention_mask):
        """Return classification logits of shape (batch_size, num_classes)."""
        # RoBERTa branch
        outputs = self.roberta(input_ids=input_ids, attention_mask=attention_mask)
        # (batch_size, seq_length, hidden_size), with dropout applied
        token_embeddings = self.dropout(outputs.last_hidden_state)

        # Conv1d expects channels first: (batch_size, hidden_size, seq_length)
        conv_input = token_embeddings.permute(0, 2, 1)

        # An unpadded convolution with kernel k shortens the sequence by k - 1;
        # zero-pad each output so all branches are seq_length long again.
        conv2_output = F.pad(self.conv2(conv_input), (0, 1))
        conv3_output = F.pad(self.conv3(conv_input), (1, 1))
        conv5_output = F.pad(self.conv5(conv_input), (2, 2))
        conv7_output = F.pad(self.conv7(conv_input), (3, 3))

        # Concatenate along the channel dimension: (batch_size, 4 * 128, seq_length)
        inception_output = torch.cat([conv2_output, conv3_output, conv5_output, conv7_output], dim=1)
        inception_output = inception_output.permute(0, 2, 1)  # (batch_size, seq_length, 4 * 128)

        # Token embeddings + conv features: (batch_size, seq_length, hidden_size + 4 * 128)
        concatenated_features = torch.cat([token_embeddings, inception_output], dim=2)

        # True marks padded positions, which attention must not use as keys
        key_padding_mask = ~attention_mask.bool()  # (batch_size, seq_length)

        attn_output, _ = self.attention(
            concatenated_features,
            concatenated_features,
            concatenated_features,
            key_padding_mask=key_padding_mask
        )  # same shape as concatenated_features

        # Mean over real tokens only. A plain average over all seq_length
        # positions would mix in the outputs computed at padding positions,
        # which dominate for documents much shorter than the padded length.
        pooled_output = self._masked_mean_pool(attn_output, attention_mask)

        # dense layer
        dense_output = self.dense(pooled_output)

        # Final dropout and classification layer
        dense_output = self.final_dropout(dense_output)
        logits = self.fc(dense_output)  # Shape: (batch_size, num_classes)

        return logits

## Setting the model
# One output logit per label column; the module is moved to `device` in place.
model = RobertaInceptionAttentionImproved(num_classes=len(target_list)).to(device)

## Loss Fn
def loss_fn(outputs, targets):
    """Return the mean binary cross-entropy between logits and targets.

    ``outputs`` are raw (pre-sigmoid) scores; the sigmoid is applied inside
    the loss. The result is averaged over every element of the batch.
    """
    return torch.nn.functional.binary_cross_entropy_with_logits(outputs, targets)

# Optimizer: AdamW over every parameter of the model (the RoBERTa encoder and
# the layers stacked on top of it), with the same learning rate and a weight
# decay of 1e-4 applied to all of them.
optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=1e-4)

## Training function
def train_model(training_loader, model, optimizer, accumulation_steps=1, threshold=0.5):
    """Train ``model`` for one epoch.

    Args:
        training_loader: iterable of batches; each batch is a dict with the
            tensors 'input_ids', 'attention_mask' and 'targets'.
        model: module called as ``model(input_ids, attention_mask)`` that
            returns logits with the same shape as the targets.
        optimizer: optimizer bound to the model's parameters.
        accumulation_steps: number of consecutive batches whose gradients are
            summed before one optimizer step. With the default of 1 the
            optimizer steps after every batch.
        threshold: sigmoid probability at or above which a label counts as
            predicted when computing the training accuracy.

    Returns:
        Tuple ``(model, accuracy, mean_loss)`` where ``accuracy`` is the
        fraction of individual label decisions that match the targets and
        ``mean_loss`` is the mean of the per-batch losses.

    Raises:
        ValueError: if ``accumulation_steps`` is smaller than 1.
    """
    if accumulation_steps < 1:
        raise ValueError("accumulation_steps must be >= 1")

    losses = []
    correct_predictions = 0
    num_samples = 0
    total_batches = len(training_loader)

    # Set model to training mode (activate dropout, batch norm)
    model.train()
    # Start the first accumulation window from clean gradients.
    optimizer.zero_grad()

    for batch_idx, data in enumerate(training_loader, start=1):
        ids = data['input_ids'].to(device, dtype=torch.long, non_blocking=True)
        mask = data['attention_mask'].to(device, dtype=torch.long, non_blocking=True)
        targets = data['targets'].to(device, dtype=torch.float, non_blocking=True)

        outputs = model(ids, mask)
        loss = loss_fn(outputs, targets)
        losses.append(loss.item())

        # Accuracy bookkeeping only; keep it out of the autograd graph.
        with torch.no_grad():
            predictions = (torch.sigmoid(outputs) >= threshold).to(targets.dtype)
            correct_predictions += (predictions == targets).sum().item()
        num_samples += targets.numel()

        # Scale so the accumulated gradient is the average over the window.
        # A shorter final window is scaled by the same factor.
        (loss / accumulation_steps).backward()

        # Step at the end of each full window, and once more for a trailing
        # partial window so no gradients are dropped at the end of the epoch.
        if batch_idx % accumulation_steps == 0 or batch_idx == total_batches:
            nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            optimizer.zero_grad()

        # Clearing GPU cache
        torch.cuda.empty_cache()

    return model, float(correct_predictions) / num_samples, np.mean(losses)


def eval_model(validation_loader, model, threshold=0.5, target_list=None):
    """Evaluate ``model`` on a data loader and print multi-label metrics.

    Args:
        validation_loader: iterable of batches; each batch is a dict with the
            tensors 'input_ids', 'attention_mask' and 'targets'.
        model: module called as ``model(input_ids, attention_mask)`` that
            returns logits with the same shape as the targets.
        threshold: sigmoid probability at or above which a label counts as
            predicted.
        target_list: optional label names shown in the classification report;
            when None the report uses the label indices.

    Returns:
        Tuple ``(f1, average_loss)``: the weighted F1 score and the mean of
        the per-batch losses.
    """
    model.eval()
    batch_probs = []
    batch_targets = []
    losses = []

    with torch.no_grad():
        for data in validation_loader:
            ids = data['input_ids'].to(device, dtype=torch.long, non_blocking=True)
            mask = data['attention_mask'].to(device, dtype=torch.long, non_blocking=True)
            targets = data['targets'].to(device, dtype=torch.float, non_blocking=True)

            outputs = model(ids, mask)
            losses.append(loss_fn(outputs, targets).item())

            batch_probs.append(torch.sigmoid(outputs).cpu().numpy())
            batch_targets.append(targets.cpu().numpy())

            torch.cuda.empty_cache()

    final_probs = np.concatenate(batch_probs, axis=0)
    final_targets = np.concatenate(batch_targets, axis=0)
    # Threshold exactly once, on the probabilities.
    final_outputs = final_probs >= threshold

    # zero_division=0 yields the same values as the default for labels that
    # are never predicted, without emitting a warning for each of them.
    acc = accuracy_score(final_targets, final_outputs)
    f1 = f1_score(final_targets, final_outputs, average='weighted', zero_division=0)
    precision = precision_score(final_targets, final_outputs, average='weighted', zero_division=0)
    recall = recall_score(final_targets, final_outputs, average='weighted', zero_division=0)
    hamming = hamming_loss(final_targets, final_outputs)

    # ROC AUC is undefined when a label column holds a single class; report
    # NaN in that case rather than aborting the whole evaluation.
    try:
        auc_roc = roc_auc_score(final_targets, final_probs, average='weighted', multi_class='ovr')
    except ValueError:
        auc_roc = float('nan')
    aupr = average_precision_score(final_targets, final_probs, average='weighted')

    average_loss = np.mean(losses)

    print(f"Accuracy: {acc}")
    print(f"F1 Score: {f1}")
    print(f"Precision: {precision}")
    print(f"Recall: {recall}")
    print(f"Hamming Loss: {hamming}")
    print(f"Average Loss: {average_loss}")
    print(f"AUC-ROC: {auc_roc}")
    print(f"AUPR: {aupr}")
    print("\nClassification Report:\n", classification_report(final_targets, final_outputs, target_names=target_list, zero_division=0))

    return f1, average_loss


# Learning Rate Scheduler
# Anneal over exactly EPOCHS scheduler steps. With a T_max shorter than the
# run (it used to be 10 for 12 epochs) the cosine schedule reaches a learning
# rate of 0 before training ends, so one whole epoch performs no updates and
# the rate then starts rising again.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)

# File that holds the weights of the best epoch (by validation F1).
BEST_MODEL_PATH = "ohsumed_RIncDNet128_32_base_best.bin"
# Label names for the classification report (it requires strings).
report_names = [str(name) for name in target_list]

# Training & Evaluation Loop
start = time.time()

history = defaultdict(list)
# Start below any reachable F1 so the first epoch always writes a checkpoint,
# even when the validation F1 is exactly 0.
best_f1 = float("-inf")

for epoch in range(1, EPOCHS+1):
    print(f'Epoch {epoch}/{EPOCHS}')
    model, train_acc, train_loss = train_model(train_data_loader, model, optimizer)
    val_f1, val_loss = eval_model(val_data_loader, model, threshold=THRESHOLD, target_list=report_names)

    history['train_acc'].append(train_acc)
    history['train_loss'].append(train_loss)
    history['val_f1'].append(val_f1)
    history['val_loss'].append(val_loss)

    scheduler.step()  # Step scheduler after each epoch

    if val_f1 > best_f1:
        torch.save(model.state_dict(), BEST_MODEL_PATH)
        best_f1 = val_f1

end = time.time()
print(f"Total training and evaluation time: {end - start} seconds")


## Testing
print("\n\nTesting\n\n")
# Rebuild the architecture and restore the best checkpoint. map_location lets
# a checkpoint saved on GPU load on a CPU-only machine; weights_only restricts
# unpickling to tensor data, which is all a state_dict contains.
model = RobertaInceptionAttentionImproved(num_classes=len(target_list))
model.load_state_dict(torch.load(BEST_MODEL_PATH, map_location=device, weights_only=True))
model = model.to(device)

start = time.time()
eval_model(test_data_loader, model, threshold=THRESHOLD, target_list=report_names)
end = time.time()
print(f"Total test-set evaluation time: {end - start} seconds")

Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Epoch 1/12


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Accuracy: 0.21144674085850557
F1 Score: 0.3554755611091108
Precision: 0.591520654485336
Recall: 0.290521327014218
Hamming Loss: 0.05799405543651068
Average Loss: 0.18617926277220248
AUC-ROC: 0.8588159328530781
AUPR: 0.5806460304716774

Classification Report:
               precision    recall  f1-score   support

           0       0.89      0.16      0.27       102
           1       0.00      0.00      0.00        36
           2       0.00      0.00      0.00        16
           3       0.83      0.80      0.82       234
           4       0.00      0.00      0.00        51
           5       0.60      0.32      0.42       109
           6       0.00      0.00      0.00        16
           7       0.80      0.16      0.27        99
           8       0.00      0.00      0.00        20
           9       0.95      0.16      0.28       122
          10       0.00      0.00      0.00        26
          11       0.00      0.00      0.00       110
          12       0.69      0.23    

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Accuracy: 0.3322734499205087
F1 Score: 0.5030878484601247
Precision: 0.7036692022287175
Recall: 0.44218009478672987
Hamming Loss: 0.049319140111978986
Average Loss: 0.1591376055032015
AUC-ROC: 0.8876217540813587
AUPR: 0.6450076917156568

Classification Report:
               precision    recall  f1-score   support

           0       0.73      0.39      0.51       102
           1       0.00      0.00      0.00        36
           2       0.00      0.00      0.00        16
           3       0.86      0.75      0.80       234
           4       0.53      0.35      0.42        51
           5       0.66      0.69      0.67       109
           6       0.00      0.00      0.00        16
           7       0.79      0.54      0.64        99
           8       0.75      0.15      0.25        20
           9       0.77      0.34      0.47       122
          10       0.81      0.50      0.62        26
          11       0.84      0.68      0.75       110
          12       0.74      0.49  

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Accuracy: 0.34181240063593005
F1 Score: 0.5410518918511331
Precision: 0.7506668519555709
Recall: 0.46445497630331756
Hamming Loss: 0.04700352526439483
Average Loss: 0.14377968590706586
AUC-ROC: 0.9043955049267689
AUPR: 0.6891741158965891

Classification Report:
               precision    recall  f1-score   support

           0       0.83      0.29      0.43       102
           1       0.00      0.00      0.00        36
           2       0.00      0.00      0.00        16
           3       0.84      0.85      0.85       234
           4       0.58      0.49      0.53        51
           5       0.68      0.61      0.64       109
           6       0.00      0.00      0.00        16
           7       0.77      0.52      0.62        99
           8       0.54      0.35      0.42        20
           9       0.86      0.39      0.54       122
          10       0.70      0.54      0.61        26
          11       0.85      0.74      0.79       110
          12       0.76      0.66 

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Accuracy: 0.3728139904610493
F1 Score: 0.6098059742186299
Precision: 0.754362577528598
Recall: 0.5417061611374407
Hamming Loss: 0.04420405059791249
Average Loss: 0.13519354909658432
AUC-ROC: 0.9104282953489152
AUPR: 0.7067854062237813

Classification Report:
               precision    recall  f1-score   support

           0       0.81      0.57      0.67       102
           1       0.00      0.00      0.00        36
           2       0.00      0.00      0.00        16
           3       0.83      0.88      0.86       234
           4       0.58      0.49      0.53        51
           5       0.68      0.62      0.65       109
           6       0.00      0.00      0.00        16
           7       0.84      0.52      0.64        99
           8       0.67      0.30      0.41        20
           9       0.79      0.45      0.57       122
          10       0.67      0.54      0.60        26
          11       0.87      0.73      0.79       110
          12       0.80      0.60    

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Accuracy: 0.36327503974562797
F1 Score: 0.5982654220761383
Precision: 0.7783796712133102
Recall: 0.5127962085308057
Hamming Loss: 0.04492984032625976
Average Loss: 0.1317170662805438
AUC-ROC: 0.9139434224984487
AUPR: 0.716850509667906

Classification Report:
               precision    recall  f1-score   support

           0       0.88      0.48      0.62       102
           1       0.00      0.00      0.00        36
           2       0.00      0.00      0.00        16
           3       0.89      0.76      0.82       234
           4       0.55      0.55      0.55        51
           5       0.70      0.68      0.69       109
           6       1.00      0.06      0.12        16
           7       0.81      0.52      0.63        99
           8       0.50      0.30      0.37        20
           9       0.86      0.36      0.51       122
          10       0.68      0.50      0.58        26
          11       0.89      0.71      0.79       110
          12       0.83      0.64    

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Accuracy: 0.3815580286168522
F1 Score: 0.6335671418465139
Precision: 0.7549327104376744
Recall: 0.5706161137440758
Hamming Loss: 0.04365106794774314
Average Loss: 0.12946748211979867
AUC-ROC: 0.918207291625712
AUPR: 0.7204351361203469

Classification Report:
               precision    recall  f1-score   support

           0       0.77      0.66      0.71       102
           1       0.42      0.14      0.21        36
           2       0.50      0.06      0.11        16
           3       0.87      0.83      0.85       234
           4       0.56      0.63      0.59        51
           5       0.70      0.64      0.67       109
           6       1.00      0.19      0.32        16
           7       0.68      0.66      0.67        99
           8       0.50      0.45      0.47        20
           9       0.73      0.50      0.59       122
          10       0.59      0.62      0.60        26
          11       0.83      0.75      0.78       110
          12       0.77      0.64    

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Accuracy: 0.3847376788553259
F1 Score: 0.6353060321131371
Precision: 0.7623824770289784
Recall: 0.5658767772511848
Hamming Loss: 0.043201769544480545
Average Loss: 0.12667267471551896
AUC-ROC: 0.9187133669901923
AUPR: 0.7207708715677534

Classification Report:
               precision    recall  f1-score   support

           0       0.83      0.62      0.71       102
           1       0.50      0.08      0.14        36
           2       0.50      0.06      0.11        16
           3       0.88      0.79      0.83       234
           4       0.58      0.55      0.57        51
           5       0.70      0.72      0.71       109
           6       1.00      0.19      0.32        16
           7       0.73      0.61      0.66        99
           8       0.47      0.40      0.43        20
           9       0.73      0.50      0.60       122
          10       0.59      0.50      0.54        26
          11       0.86      0.78      0.82       110
          12       0.85      0.62  

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Accuracy: 0.3942766295707472
F1 Score: 0.6496119220744921
Precision: 0.7576057003732459
Recall: 0.5876777251184834
Hamming Loss: 0.04278703255685353
Average Loss: 0.12631057240068913
AUC-ROC: 0.9192475688488765
AUPR: 0.7223580617802603

Classification Report:
               precision    recall  f1-score   support

           0       0.85      0.60      0.70       102
           1       0.71      0.14      0.23        36
           2       0.67      0.12      0.21        16
           3       0.88      0.81      0.84       234
           4       0.61      0.65      0.63        51
           5       0.69      0.70      0.69       109
           6       0.75      0.19      0.30        16
           7       0.73      0.59      0.65        99
           8       0.50      0.40      0.44        20
           9       0.76      0.47      0.58       122
          10       0.64      0.54      0.58        26
          11       0.86      0.79      0.82       110
          12       0.81      0.62   

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Accuracy: 0.39666136724960255
F1 Score: 0.6459425563316167
Precision: 0.7659211505611402
Recall: 0.5791469194312796
Hamming Loss: 0.042337734153590934
Average Loss: 0.1262842308729887
AUC-ROC: 0.9196702835838477
AUPR: 0.7217662134767878

Classification Report:
               precision    recall  f1-score   support

           0       0.82      0.62      0.70       102
           1       0.62      0.14      0.23        36
           2       0.80      0.25      0.38        16
           3       0.88      0.80      0.84       234
           4       0.61      0.59      0.60        51
           5       0.70      0.71      0.70       109
           6       0.75      0.19      0.30        16
           7       0.74      0.59      0.66        99
           8       0.50      0.45      0.47        20
           9       0.75      0.48      0.58       122
          10       0.62      0.58      0.60        26
          11       0.84      0.79      0.81       110
          12       0.83      0.64  

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Accuracy: 0.39666136724960255
F1 Score: 0.6506684947999588
Precision: 0.7633770688688767
Recall: 0.5843601895734597
Hamming Loss: 0.04216492707541301
Average Loss: 0.12565878611057996
AUC-ROC: 0.9196252720917195
AUPR: 0.7214535813609168

Classification Report:
               precision    recall  f1-score   support

           0       0.82      0.62      0.70       102
           1       0.50      0.14      0.22        36
           2       0.80      0.25      0.38        16
           3       0.87      0.82      0.84       234
           4       0.61      0.59      0.60        51
           5       0.70      0.71      0.70       109
           6       0.75      0.19      0.30        16
           7       0.75      0.60      0.66        99
           8       0.53      0.45      0.49        20
           9       0.74      0.48      0.58       122
          10       0.64      0.54      0.58        26
          11       0.84      0.79      0.82       110
          12       0.86      0.64  

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Accuracy: 0.39666136724960255
F1 Score: 0.6506684947999588
Precision: 0.7633770688688767
Recall: 0.5843601895734597
Hamming Loss: 0.04216492707541301
Average Loss: 0.12565878611057996
AUC-ROC: 0.9196252720917195
AUPR: 0.7214535813609168

Classification Report:
               precision    recall  f1-score   support

           0       0.82      0.62      0.70       102
           1       0.50      0.14      0.22        36
           2       0.80      0.25      0.38        16
           3       0.87      0.82      0.84       234
           4       0.61      0.59      0.60        51
           5       0.70      0.71      0.70       109
           6       0.75      0.19      0.30        16
           7       0.75      0.60      0.66        99
           8       0.53      0.45      0.49        20
           9       0.74      0.48      0.58       122
          10       0.64      0.54      0.58        26
          11       0.84      0.79      0.82       110
          12       0.86      0.64  

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Accuracy: 0.3958664546899841
F1 Score: 0.6467896062671984
Precision: 0.7655216374020611
Recall: 0.5786729857819906
Hamming Loss: 0.04230317273795535
Average Loss: 0.12553037758916616
AUC-ROC: 0.9194071478081028
AUPR: 0.7213782271356083

Classification Report:
               precision    recall  f1-score   support

           0       0.82      0.62      0.70       102
           1       0.56      0.14      0.22        36
           2       0.75      0.19      0.30        16
           3       0.88      0.81      0.84       234
           4       0.61      0.59      0.60        51
           5       0.70      0.70      0.70       109
           6       0.75      0.19      0.30        16
           7       0.75      0.60      0.66        99
           8       0.53      0.45      0.49        20
           9       0.76      0.47      0.58       122
          10       0.64      0.54      0.58        26
          11       0.84      0.79      0.82       110
          12       0.85      0.62   

Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  model.load_state_dict(torch.load("ohsumed_RIncDNet128_32_base_best.bin"))
