In [1]:
from Bio import SeqIO
import csv
import re
from dataclasses import dataclass

In [1]:
from transformers import AutoTokenizer, AutoModelForSequenceClassification, AdamW
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from concurrent.futures import ProcessPoolExecutor, as_completed
from torch.utils.data import DataLoader
from datasets import load_dataset
from torch.utils.data import Dataset
from typing import Dict, Sequence
from dataclasses import dataclass
from torch.nn import Softmax
from Bio import SeqIO
from torch import nn
import transformers
import numpy as np
import threading
import argparse
import torch
import csv
import re
import os

In [12]:
input_seq_file = "C:/Users/anjan/Downloads/ncbi_dataset/ncbi_dataset/data/gene.fna"
mut_file = "C:/Users/anjan/variant_data.csv"

In [4]:
for record in SeqIO.parse(input_seq_file, "fasta"):
    sequence = record.seq
    header = record.description
    des = header.split(' ')
    position = des[0].split(':')
    start, end = position[1].split('-')
print(len(sequence))
print(start, end)
start = int(start)
end = int(end)

84761
32315508 32400268


In [91]:
import vcf
variant_data = "C:/Users/anjan/Downloads/ClinVar variants with precise endpoints (1).VCF"
vcf_reader = vcf.Reader(open(variant_data, 'r'))
variant = []

# Iterate over each variant in the VCF file
for record in vcf_reader:
    # Access variant information using record attributes 
    if record.POS >= start and record.POS <= end:
        variant.append({
            "variant_id" : record.ID,
            "pos" : record.POS,
            "ref" : record.REF,
            "alt" : record.ALT[0],
            "variant": record.INFO["CLNVC"][0]
        })
        
fields = ['variant_id', 'pos', 'ref', 'alt', 'variant']

with open("variant_data.csv", 'w') as csvfile:
    # creating a csv dict writer object
    writer = csv.DictWriter(csvfile, fieldnames=fields, lineterminator = '\n')
    # writing headers (field names)
    writer.writeheader()
    writer.writerows(variant)



In [5]:
# Read mutations from CSV file
mutations = []
with open(mut_file , 'r') as csvfile:
    reader = csv.reader(csvfile)
    next(reader)  # Skip header if present
    for row in reader:
        if row:
            start_position = int(row[1])-start
            end_position = start_position
            consequence = row[4]
            ref = row[2]
            alt = row[3]
            mutations.append((start_position, ref, alt, consequence))

In [32]:
snv = "single_nucleotide_variant"
dup = "Duplication"
ins = "Insertion"
delt = "Deletion"
last_index = 0
def embed_mutations(sequence, mutations):
    alt_length = 0
    embedded_sequence = list(sequence)
    seq = list(sequence)
    mutation_embed = []
    for mutation in mutations:
        start_position, ref, alt, consequence = mutation
        start_position = start_position + alt_length
        if consequence == snv:
            if embedded_sequence[start_position] == ref:
                embedded_sequence[start_position] = alt
                mutation_embed.append([start_position, start_position, consequence])
        elif consequence == dup or consequence == ins:
            if embedded_sequence[start_position] == ref:
                embedded_sequence[start_position:start_position] = alt
                mutation_embed.append([start_position, start_position+len(alt), consequence])
                alt_length += len(alt)
        elif consequence == delt:
            if len(ref) > len(alt):
                if ''.join(embedded_sequence[start_position:(start_position+len(ref))]) == ref:
                    del embedded_sequence[start_position:start_position+len(ref)]
                    mutation_embed.append([start_position, start_position+len(ref), consequence])
                    alt_length -= len(ref)
            else:
                if ''.join(embedded_sequence[start_position:(start_position+len(alt))]) == ref:
                    del embedded_sequence[start_position:start_position+len(alt)]
                    mutation_embed.append([start_position, start_position+len(alt), consequence])
                    alt_length -= len(alt)
        
    return ''.join(embedded_sequence), mutation_embed
reference_sequence = sequence
embedded_sequence, mutation_embed = embed_mutations(reference_sequence, mutations)
print(len(embedded_sequence))
print(len(reference_sequence))

83535
84761
[[2, 2, 'single_nucleotide_variant'], [3, 3, 'single_nucleotide_variant'], [11, 11, 'single_nucleotide_variant'], [18, 30, 'Duplication'], [36, 36, 'single_nucleotide_variant'], [37, 37, 'single_nucleotide_variant'], [49, 49, 'single_nucleotide_variant'], [59, 59, 'single_nucleotide_variant'], [66, 456, 'Insertion'], [458, 458, 'single_nucleotide_variant'], [476, 476, 'single_nucleotide_variant'], [478, 478, 'single_nucleotide_variant'], [513, 513, 'single_nucleotide_variant'], [539, 543, 'Deletion'], [537, 537, 'single_nucleotide_variant'], [540, 544, 'Insertion'], [544, 544, 'single_nucleotide_variant'], [545, 545, 'single_nucleotide_variant'], [546, 546, 'single_nucleotide_variant'], [547, 547, 'single_nucleotide_variant'], [548, 548, 'single_nucleotide_variant'], [549, 549, 'single_nucleotide_variant'], [550, 550, 'single_nucleotide_variant'], [551, 551, 'single_nucleotide_variant'], [553, 553, 'single_nucleotide_variant'], [556, 556, 'single_nucleotide_variant'], [557,

In [56]:
def split_sequence_with_mutations(sequence, mutations, subsequence_length=512):
    subsequences = []
    for i in range(0, len(sequence), subsequence_length):
        subsequence = sequence[i:i+subsequence_length]
        mutation_vector = [0, 0, 0, 0]  # Initialize vector for SNV, Duplication, Deletion
        for start_position, end_position, consequence in mutations:
            if start_position < i + subsequence_length and start_position >= i:
                if consequence == snv:
                    mutation_vector[0] = 1
                elif consequence == ins:
                    mutation_vector[1] = 1
                elif consequence == dup:
                    mutation_vector[2] = 1
                elif consequence == delt:
                    mutation_vector[3] = 1
        subsequences.append((subsequence, mutation_vector))
    return subsequences
subsequences_with_mutations = split_sequence_with_mutations(embedded_sequence, mutation_embed)
subsequences_without_mutation = split_sequence_with_mutations(sequence, [])
print(len(subsequences_without_mutation))


166


In [55]:
def write_sequences_to_csv(subsequences_without_mutation, subsequences_with_mutations, output_file):
    with open(output_file, 'w', newline='') as file:
        writer = csv.writer(file)
        writer.writerow(['Sequence', 'Mutation_Vector'])  # Write header for reference sequence
        for mutated_subsequence, mutation_vector in subsequences_without_mutation:
            writer.writerow([mutated_subsequence, mutation_vector])
        
        # Write mutated subsequences with mutation vectors
        for mutated_subsequence, mutation_vector in subsequences_with_mutations:
            writer.writerow([mutated_subsequence, mutation_vector])
output_file = "sequences_with_mutations.csv"
write_sequences_to_csv(subsequences_without_mutation, subsequences_with_mutations, output_file)

In [46]:
def split_sequence(sequence, max_length):
    """Split a sequence into subsequences of specified length."""
    subsequences = []
    start_index = 0
    subsequence_length = 0
    last_open_bracket = -1
    last_close_bracket = -1

    for i, char in enumerate(sequence):
        if char == "[":
            last_open_bracket = i
            
        elif char == "]":
            last_close_bracket = i

        subsequence_length += 1
        if subsequence_length > max_length:
            if last_open_bracket != -1:
                subsequences.append(sequence[start_index:last_open_bracket])
                start_index = last_open_bracket
                subsequence_length = i - start_index + 1
                last_open_bracket = -1
                last_close_bracket = -1
            else:
                # If no close bracket found within the max length, start a new subsequence
                subsequences.append(sequence[start_index:i])
                start_index = i
                subsequence_length = 0
                
    # Add the remaining part of the sequence as the last subsequence
    subsequences.append(sequence[start_index:])
    return subsequences 

def find_distance(sequence):
    open_bracket_pos = None
    close_bracket_pos = None
    distance = 0

    for i, char in enumerate(sequence):
        if char == "[":
            open_bracket_pos = i
        elif char == "]":
            close_bracket_pos = i
            if open_bracket_pos is not None:
                distance += close_bracket_pos - open_bracket_pos

    return distance
    

def write_sequences_to_file(reference_sequence, mutated_sequence, subsequence_length, output_file):
    """Write subsequences of reference and mutated sequences along with their labels to a text file."""
    # Split sequences into subsequences
    mutated_subsequences = split_sequence(mutated_sequence, subsequence_length)
    reference_subsequences = []
    i =0
    with open("Embedded.csv", 'w', newline='') as file:
        writer1 = csv.writer(file, lineterminator = '\n')
        writer1.writerow(['Sequence'])
        for seque in mutated_subsequences:
            writer1.writerow([seque])

    for seq in mutated_subsequences:
        distance = find_distance(seq)
        reference_seq = reference_sequence[:(len(seq)-distance)]
        reference_sequence = reference_sequence[(len(seq)-distance):]
        reference_subsequences.append(reference_seq)
        # Write subsequences and labels to the file
    
    with open("Reference.csv", 'w', newline='') as file:
        writer1 = csv.writer(file, lineterminator = '\n')
        writer1.writerow(['Sequence'])
        for seque in reference_subsequences:
            writer1.writerow([seque])
            
    with open(output_file, 'w', newline='') as csvfile:
        writer = csv.writer(csvfile, lineterminator = '\n')
        writer.writerow(['Sequence', 'Label'])
        for i, (ref_subseq, mut_subseq) in enumerate(zip(reference_subsequences, mutated_subsequences), start=1):
            writer.writerow([ref_subseq, 0])
            writer.writerow([mut_subseq, 1])

# Example usage
reference_sequence = sequence  # Your reference sequence
mutated_sequence = embedded_sequence    # Your mutated sequence
subsequence_length = 512    # Length of each subsequence
output_file = "output.csv"  # Output file name

# Write sequences to file
write_sequences_to_file(reference_sequence, mutated_sequence, subsequence_length, output_file)


In [2]:
tokenizer = AutoTokenizer.from_pretrained(
        "zhihan1996/DNABERT-2-117M",
        model_max_length=512,
        padding_side="right",
        use_fast=True,
        trust_remote_code=True,
    )

In [3]:
model = transformers.AutoModelForSequenceClassification.from_pretrained(
        "zhihan1996/DNABERT-2-117M",
        trust_remote_code=True,
    )
# Modify the classification layer for multi-label classification
num_labels = 4  # Number of classes in your dataset
model.classifier = torch.nn.Linear(model.classifier.in_features, num_labels)

Some weights of the model checkpoint at zhihan1996/DNABERT-2-117M were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.decoder.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at zhihan1996/DNABERT-2-117M and are newly ini

In [4]:
@dataclass
class DataCollatorForSupervisedDataset(object):
    """Collate examples for supervised fine-tuning."""

    tokenizer: transformers.PreTrainedTokenizer

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        input_ids = [torch.tensor(instance["input_ids"], dtype=torch.long) for instance in instances]
        labels = [torch.tensor(eval(instance["labels"]), dtype=torch.long) for instance in instances]
        input_ids = torch.nn.utils.rnn.pad_sequence(
            input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
        )
        labels = torch.nn.utils.rnn.pad_sequence(
            labels, batch_first=True, padding_value=0  # Assuming your labels are 0-indexed
        )
        return dict(
            input_ids=input_ids,
            labels=labels,
            attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
        )

In [5]:
class SupervisedDataset(Dataset):
    """Dataset for supervised fine-tuning."""

    def __init__(self,
                 data_path: str,
                 tokenizer: transformers.PreTrainedTokenizer):

        super(SupervisedDataset, self).__init__()

        # load data from the disk
        with open(data_path, "r") as f:
            data = list(csv.reader(f))[1:]
        if len(data[0]) == 2:
            # data is in the format of [text, label]
            texts = [d[0] for d in data]
            labels = [d[1] for d in data]

        output = tokenizer(
            texts,
            return_tensors="pt",
            padding="longest",
            max_length=tokenizer.model_max_length,
            truncation=True,
        )

        self.input_ids = output["input_ids"]
        self.attention_mask = output["attention_mask"]
        self.labels = labels
        self.num_labels = len(labels)

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, i) -> Dict[str, str]:
        return dict(input_ids=self.input_ids[i], labels=self.labels[i])

In [6]:
cpu_threads =2

In [7]:
def tokenize_function(examples):
    return tokenizer(examples["Sequence"], truncation=True)

In [19]:
data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
# Load dataset and split into train and test sets
train_dataset = load_dataset('csv', data_files={"train": "C:/Users/anjan/mutation_classification/train.csv"})
test_dataset = load_dataset('csv', data_files={"test": "C:/Users/anjan/mutation_classification/test.csv"})

tokenized_train_datasets = train_dataset.map(tokenize_function, batched=True, batch_size=256, remove_columns=["Sequence"])
tokenized_test_datasets = test_dataset.map(tokenize_function, batched=True, batch_size=256, remove_columns=["Sequence"])
tokenized_train_datasets = tokenized_train_datasets.with_format("torch")
tokenized_test_datasets = tokenized_test_datasets.with_format("torch")

train_loader = DataLoader(tokenized_train_datasets["train"] , batch_size=16, collate_fn=data_collator)
test_loader = DataLoader(tokenized_test_datasets["test"] , batch_size=16, collate_fn=data_collator)

Generating train split: 0 examples [00:00, ? examples/s]

Map:   0%|          | 0/1283 [00:00<?, ? examples/s]

In [None]:
import torch
from torch.utils.data import DataLoader, TensorDataset
from sklearn.model_selection import train_test_split
import torch.nn.functional as F


# Define optimizer and learning rate scheduler
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-5)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.1)

# Training loop
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
epochs = 5

for epoch in range(epochs):
    model.train()
    total_loss = 0
    total_samples = 0
    total_correct = 0
    total_false = 0
    for batch in train_loader:
        batch['labels'] = torch.tensor(batch['labels']).to(device)
        inputs = {k: v.to(device) for k, v in batch.items()}
        labels = inputs.pop('labels')
        logits = model(**inputs).logits
        
        # Ensure logits and labels have the same shape
        if logits.shape != labels.shape:
            raise ValueError(f"Shape mismatch: Logits shape {logits.shape} != Labels shape {labels.shape}")
        
        
        outputs = model(**inputs)
        logits = outputs[0]
        
        # Calculate the loss (use binary cross-entropy since it's multi-label classification)
        loss = F.binary_cross_entropy_with_logits(logits, labels.float())
        print("Loss", loss.item())

        # Now you can proceed with backward pass and optimization
        loss.backward()
        optimizer.step()

        # Get the loss value
        total_loss += loss.item()
        scheduler.step()
        optimizer.zero_grad()
        
        # Calculate total correct predictions
        predicted_labels = (torch.sigmoid(logits) > 0.5).int()
        total_correct += (predicted_labels == labels).sum().item()
        total_false += (predicted_labels != labels).sum().item()
        total_samples += labels.size(0)
    
    avg_train_loss = total_loss / len(train_loader)
    train_accuracy = total_correct / total_samples
    
    
     # Validation loop
    model.eval()
    val_loss = 0
    num_correct = 0
    num_false = 0
    predicted_labels_list = []
    true_labels_list = []
    with torch.no_grad():
        for batch in test_loader:
            batch['labels'] = torch.tensor(batch['labels']).to(device)
            inputs = {k: v.to(device) for k, v in batch.items()}
            labels = inputs.pop('labels')
            logits = model(**inputs).logits
#             inputs = {'input_ids': batch['input_ids'], 'attention_mask': batch['attention_mask']}
            outputs = model(**inputs)
            logits = outputs.logits
            # Calculate the loss
             # Calculate the loss (use binary cross-entropy since it's multi-label classification)
            val_loss = F.binary_cross_entropy_with_logits(logits, labels.float())

            # Calculate total correct predictions
            predicted_labels = (torch.sigmoid(logits) > 0.5).int()
            num_correct += torch.sum(predicted_labels == labels)
            num_false += torch.sum(predicted_labels != labels)

            # Collect predicted and true labels for later calculation of F1-score and recall
            predicted_labels_list.extend(predicted_labels.cpu().numpy())
            true_labels_list.extend(labels.cpu().numpy())

    # Calculate average validation loss
    avg_val_loss = val_loss / len(test_loader)

    # Calculate accuracy
    accuracy = accuracy_score(true_labels_list, predicted_labels_list)

    # Calculate precision
    precision = precision_score(true_labels_list, predicted_labels_list, average='weighted')

    # Calculate recall
    recall = recall_score(true_labels_list, predicted_labels_list, average='weighted')

    # Calculate F1 score
    f1 = f1_score(true_labels_list, predicted_labels_list, average='weighted')
    
    
    print(f"Epoch {epoch+1}/{epochs}, Train Loss: {avg_train_loss:.4f}, Val Acc: {accuracy:.4f}, f1: {f1}, precision: {precision}, Recall: {recall}")

# Save the trained model

model.save_pretrained("trained_dnabert_model")
tokenizer.save_pretrained("trained_dnabert_model")

  input_ids = [torch.tensor(instance["input_ids"], dtype=torch.long) for instance in instances]
  batch['labels'] = torch.tensor(batch['labels']).to(device)


Loss 0.5522530674934387
Loss 0.5719839334487915
Loss 0.5687587261199951
Loss 0.5753786563873291
Loss 0.5799055099487305
Loss 0.575383186340332
Loss 0.5813748836517334
Loss 0.5677089095115662
Loss 0.5806525945663452
Loss 0.5653905868530273
Loss 0.6611979007720947
Loss 0.6973075866699219
Loss 0.7242401838302612
Loss 0.7408064603805542
Loss 0.6948264837265015
Loss 0.6962414979934692
Loss 0.7035406827926636
Loss 0.6883823275566101
Loss 0.6749321222305298
Loss 0.6714344024658203
Loss 0.710442841053009
Loss 0.651461660861969
Loss 0.6292423605918884
Loss 0.6578718423843384
Loss 0.6497449278831482
Loss 0.6357797384262085
Loss 0.7103989720344543
Loss 0.633319079875946
Loss 0.6415697932243347
Loss 0.6426853537559509
Loss 0.6353136301040649
Loss 0.5926505923271179
Loss 0.5934507846832275
Loss 0.5993504524230957
Loss 0.5983830094337463
Loss 0.602333664894104
Loss 0.5979593992233276
Loss 0.6048659682273865
Loss 0.6275296807289124
Loss 0.6192325353622437
Loss 0.6038755178451538
Loss 0.61651241779327

  batch['labels'] = torch.tensor(batch['labels']).to(device)
  _warn_prf(average, modifier, msg_start, len(result))
  input_ids = [torch.tensor(instance["input_ids"], dtype=torch.long) for instance in instances]
  batch['labels'] = torch.tensor(batch['labels']).to(device)


Epoch 1/5, Train Loss: 0.6143, Val Acc: 0.5152, f1: 0.0, precision: 0.0, Recall: 0.0
Loss 0.5406210422515869
Loss 0.5632015466690063
Loss 0.5616356730461121
Loss 0.569103479385376
Loss 0.5747525095939636
Loss 0.5746734142303467
Loss 0.5814342498779297
Loss 0.5690425038337708
Loss 0.5791658163070679
Loss 0.5775089859962463
Loss 0.661238431930542
Loss 0.6973199248313904
Loss 0.723967432975769
Loss 0.7349513173103333
Loss 0.6887363195419312
Loss 0.7031499147415161
Loss 0.7147533893585205
Loss 0.6839958429336548
Loss 0.6761089563369751
Loss 0.6672990322113037
Loss 0.7014778256416321
Loss 0.6444556713104248
Loss 0.631671130657196
Loss 0.6568597555160522
Loss 0.6448842883110046
Loss 0.6359012126922607
Loss 0.7064165472984314
Loss 0.637397050857544
Loss 0.6427258253097534
Loss 0.6416563391685486
Loss 0.6360622048377991
Loss 0.5921289324760437
Loss 0.5895805358886719
Loss 0.5939860939979553
Loss 0.6002953052520752
Loss 0.6044918894767761
Loss 0.6001992225646973
Loss 0.6016159653663635
Loss 0.6

  batch['labels'] = torch.tensor(batch['labels']).to(device)
  _warn_prf(average, modifier, msg_start, len(result))
  input_ids = [torch.tensor(instance["input_ids"], dtype=torch.long) for instance in instances]
  batch['labels'] = torch.tensor(batch['labels']).to(device)


Epoch 2/5, Train Loss: 0.6138, Val Acc: 0.5152, f1: 0.0, precision: 0.0, Recall: 0.0
Loss 0.5438728928565979
Loss 0.567997932434082
Loss 0.5645102858543396
Loss 0.5687439441680908
Loss 0.5792089700698853
Loss 0.5774344801902771
Loss 0.5839471817016602
Loss 0.5707201957702637
Loss 0.5793180465698242
Loss 0.5735538005828857
Loss 0.6644549369812012
Loss 0.6992572546005249
Loss 0.7253885865211487
Loss 0.7372778058052063
Loss 0.6937853693962097
Loss 0.6992977857589722
Loss 0.7105013132095337
Loss 0.6839407682418823
Loss 0.6695983409881592
Loss 0.6656357049942017
Loss 0.7084629535675049
Loss 0.6524154543876648
Loss 0.6286870241165161
Loss 0.6593900322914124
Loss 0.650754988193512
Loss 0.6357753872871399
Loss 0.7088546752929688
Loss 0.6334181427955627
Loss 0.6379978656768799
Loss 0.6471684575080872
Loss 0.6376996636390686
Loss 0.5943155288696289
Loss 0.5921196937561035
Loss 0.5953654646873474
Loss 0.6008151173591614
Loss 0.5978585481643677
Loss 0.5955621600151062
Loss 0.6000159978866577
Loss 

  batch['labels'] = torch.tensor(batch['labels']).to(device)
  _warn_prf(average, modifier, msg_start, len(result))
  input_ids = [torch.tensor(instance["input_ids"], dtype=torch.long) for instance in instances]
  batch['labels'] = torch.tensor(batch['labels']).to(device)


Epoch 3/5, Train Loss: 0.6140, Val Acc: 0.5152, f1: 0.0, precision: 0.0, Recall: 0.0
Loss 0.5422539710998535
Loss 0.5643614530563354
Loss 0.5674059987068176
Loss 0.5676466822624207
Loss 0.5693956017494202
Loss 0.5760989189147949
Loss 0.588692843914032
Loss 0.5698087215423584
Loss 0.577713131904602
Loss 0.5736908316612244
Loss 0.66201251745224
Loss 0.6976848840713501
Loss 0.725561797618866
Loss 0.7420612573623657
Loss 0.6940734386444092
Loss 0.701422929763794
Loss 0.7091500163078308
Loss 0.6895309686660767
Loss 0.6759456396102905
Loss 0.6679412126541138
Loss 0.7024286389350891
Loss 0.6524664163589478
Loss 0.6290110349655151
Loss 0.6583704948425293
Loss 0.6456541419029236
Loss 0.63542640209198
Loss 0.707120418548584
Loss 0.6353621482849121
Loss 0.6364451050758362
Loss 0.6428018808364868
Loss 0.6363348960876465
Loss 0.5941093564033508
Loss 0.5882591009140015
Loss 0.5993759036064148
Loss 0.5989224314689636
Loss 0.6017229557037354
Loss 0.5960763096809387
Loss 0.5999649167060852
Loss 0.62778

  batch['labels'] = torch.tensor(batch['labels']).to(device)
  _warn_prf(average, modifier, msg_start, len(result))
  input_ids = [torch.tensor(instance["input_ids"], dtype=torch.long) for instance in instances]
  batch['labels'] = torch.tensor(batch['labels']).to(device)


Epoch 4/5, Train Loss: 0.6137, Val Acc: 0.5152, f1: 0.0, precision: 0.0, Recall: 0.0
Loss 0.5433472394943237
Loss 0.5675890445709229
Loss 0.5649101734161377
Loss 0.5694451928138733
Loss 0.5754944682121277
Loss 0.5754842162132263
Loss 0.5826672315597534
Loss 0.5681098103523254
Loss 0.5839060544967651
Loss 0.572968065738678
Loss 0.6621314883232117
Loss 0.6973552107810974
Loss 0.7289947271347046
Loss 0.7395638227462769
Loss 0.6903031468391418
Loss 0.6968054175376892
Loss 0.7132270932197571
Loss 0.6870746612548828
Loss 0.6766112446784973
Loss 0.67062908411026
Loss 0.7095949649810791
Loss 0.6483606100082397
Loss 0.6312641501426697
Loss 0.6568008661270142
Loss 0.6485419869422913
Loss 0.6381157636642456
Loss 0.7067449688911438
Loss 0.638408899307251
Loss 0.6404213905334473
Loss 0.6408998966217041
Loss 0.6356832385063171
Loss 0.5951194167137146
Loss 0.5902162790298462
Loss 0.5939188599586487
Loss 0.6001932621002197
Loss 0.6048240661621094
Loss 0.6027867794036865
Loss 0.6018624305725098
Loss 0.

In [15]:
import torch
from transformers import AutoTokenizer, AutoModel


# Training loop
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
pred_dataset = load_dataset('csv', data_files="C:/Users/anjan/test.csv")
tokenized_pred_datasets = pred_dataset.map(tokenize_function, batched=True, batch_size=256, remove_columns=["Sequence"])
tokenized_pred_datasets = tokenized_pred_datasets.with_format("torch")

pred_loader = DataLoader(tokenized_pred_datasets["train"] , batch_size=1, collate_fn=data_collator)



model.eval()
val_loss = 0
num_correct = 0
num_false = 0
with torch.no_grad():
    for batch in pred_loader:
        batch['labels'] = torch.tensor(batch['labels']).to(device)
        inputs = {k: v.to(device) for k, v in batch.items()}
        labels = inputs.pop('labels')
        logits = model(**inputs).logits
#             inputs = {'input_ids': batch['input_ids'], 'attention_mask': batch['attention_mask']}
        outputs = model(**inputs)
        logits = outputs.logits
        # Calculate the loss
         # Calculate the loss (use binary cross-entropy since it's multi-label classification)
        
        # Calculate total correct predictions
        predicted_labels = (torch.sigmoid(logits) > 0.5).int()
        print(predicted_labels)



  input_ids = [torch.tensor(instance["input_ids"], dtype=torch.long) for instance in instances]
  batch['labels'] = torch.tensor(batch['labels']).to(device)


tensor([[0, 0, 0, 0]], dtype=torch.int32)
tensor([[0, 0, 0, 0]], dtype=torch.int32)
