In [1]:
import gdown
# Fetch the train/dev/test JSONL splits from Google Drive into the current directory.
drive_url_template = 'https://drive.google.com/uc?export=download&id={}'

train_file_id = '1CX_H9PXRdCEhRx14ifDObqkhuqOXCWZQ'
dev_file_id = '1XnJmiji-NCE-goVCxCKfguUnt6mMRoYU'
test_file_id = '1hqfZq0vKoBqWcUur2on0gqOqHHKOe_Jb'

gdown.download(drive_url_template.format(train_file_id), 'train_punctuation_data.jsonl', quiet=False)
gdown.download(drive_url_template.format(dev_file_id), 'dev_punctuation_data.jsonl', quiet=False)
# Kept as the last expression so the cell still displays the returned path.
gdown.download(drive_url_template.format(test_file_id), 'test_punctuation_data.jsonl', quiet=False)

Downloading...
From: https://drive.google.com/uc?export=download&id=1CX_H9PXRdCEhRx14ifDObqkhuqOXCWZQ
To: /kaggle/working/train_punctuation_data.jsonl
100%|██████████| 3.39M/3.39M [00:00<00:00, 66.2MB/s]
Downloading...
From: https://drive.google.com/uc?export=download&id=1XnJmiji-NCE-goVCxCKfguUnt6mMRoYU
To: /kaggle/working/dev_punctuation_data.jsonl
100%|██████████| 463k/463k [00:00<00:00, 121MB/s]
Downloading...
From: https://drive.google.com/uc?export=download&id=1hqfZq0vKoBqWcUur2on0gqOqHHKOe_Jb
To: /kaggle/working/test_punctuation_data.jsonl
100%|██████████| 800k/800k [00:00<00:00, 90.0MB/s]


'test_punctuation_data.jsonl'

In [2]:
import json
import torch
from torch.utils.data import DataLoader, Dataset
from transformers import MBart50TokenizerFast, MBartModel
from torch import nn, optim
from sklearn.metrics import accuracy_score
from tqdm import tqdm
from sklearn.model_selection import train_test_split


def load_tokens_and_labels(filepath):
    """Read a JSON-lines file and return parallel lists of tokens and shifted labels.

    Every non-blank line must be a JSON object with a ``tokens_per_sentence``
    entry and a ``labels`` list of integers. For each record, every stored
    label ``k`` becomes ``k - 1`` and one ``-100`` marker is placed at each
    end of the list (``-100`` is the value the training and evaluation code
    masks out via ``labels != -100``).

    Args:
        filepath: Path of the ``.jsonl`` file to read.

    Returns:
        ``(tokens, labels)`` where ``tokens[i]`` is the ``tokens_per_sentence``
        value of record ``i`` and ``labels[i]`` is its transformed label list,
        two entries longer than the stored one.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        KeyError: If a record lacks ``tokens_per_sentence`` or ``labels``.
    """
    ignore_label = -100
    tokens, labels = [], []
    # Explicit UTF-8 so non-ASCII tokens decode the same on every platform.
    with open(filepath, 'r', encoding='utf-8') as f:
        for line in f:
            # Tolerate empty lines (e.g. a stray blank line at the end of the file).
            if not line.strip():
                continue
            item = json.loads(line)
            tokens.append(item['tokens_per_sentence'])
            shifted = [l - 1 for l in item['labels']]
            labels.append([ignore_label] + shifted + [ignore_label])
    return tokens, labels


# Read the three JSONL splits (train, dev, test — in that order) into
# parallel token/label lists.
(train_tokens, train_labels), (dev_tokens, dev_labels), (test_tokens, test_labels) = map(
    load_tokens_and_labels,
    (
        "train_punctuation_data.jsonl",
        "dev_punctuation_data.jsonl",
        "test_punctuation_data.jsonl",
    ),
)


# Shared tokenizer: used by the dataset class below and by the later
# feature-extraction cells.
TOKENIZER_CHECKPOINT = 'facebook/mbart-large-50'
tokenizer = MBart50TokenizerFast.from_pretrained(TOKENIZER_CHECKPOINT)

class TokenClassificationDataset(Dataset):
    """Map-style dataset pairing pre-split sentences with their label lists.

    Items are tokenised on access with the module-level ``tokenizer``; no
    padding is applied here.
    """

    def __init__(self, tokens, labels):
        # Parallel sequences: tokens[i] is one pre-split sentence and
        # labels[i] is the label list that goes with it.
        self.tokens, self.labels = tokens, labels

    def __len__(self):
        """Number of sentences held by the dataset."""
        return len(self.tokens)

    def __getitem__(self, idx):
        """Return ``({'input_ids', 'attention_mask'}, labels)`` for sentence ``idx``."""
        sentence = self.tokens[idx]
        encoded = tokenizer(
            sentence,
            is_split_into_words=True,
            padding=False,
            truncation=True,
            return_tensors=None,
        )
        features = {
            key: torch.tensor(encoded[key])
            for key in ('input_ids', 'attention_mask')
        }
        # NOTE(review): labels are returned exactly as stored while input_ids
        # come from the tokenizer; presumably both have the same length per
        # sentence — verify against the data.
        return features, self.labels[idx]


# Collate Function
def collate_fn(batch):
    """Pad a list of dataset items into batch tensors, dropping over-long items.

    Args:
        batch: List of ``(features, labels)`` pairs, where ``features`` holds
            1-D ``input_ids`` and ``attention_mask`` tensors and ``labels`` is
            a list of ints.

    Returns:
        ``(inputs, labels, discarded)``. ``inputs`` is a dict with the stacked
        ``input_ids`` and ``attention_mask`` tensors and ``labels`` is the
        stacked label tensor, all moved to the module-level ``device``.
        ``discarded`` is the number of items dropped for having more than 60
        input ids. When every item is dropped the result is
        ``(None, None, discarded)``.
    """
    kept = []
    for sample in batch:
        if sample[0]['input_ids'].size(0) <= 60:
            kept.append(sample)

    discarded = len(batch) - len(kept)
    if not kept:
        return None, None, discarded

    # Every row is right-padded to the longest surviving input sequence.
    width = max(sample[0]['input_ids'].size(0) for sample in kept)

    def right_pad(values, fill):
        # Extend a 1-D long tensor to `width` entries using `fill`.
        tail = torch.full((width - values.size(0),), fill, dtype=torch.long)
        return torch.cat([values, tail])

    # NOTE(review): input_ids are padded with 0; confirm this matches the
    # tokenizer's pad token id (tokenizer.pad_token_id).
    input_ids = torch.stack([right_pad(sample[0]['input_ids'], 0) for sample in kept])
    attention_mask = torch.stack([right_pad(sample[0]['attention_mask'], 0) for sample in kept])
    # Label padding uses -100, the value the loops mask out with `labels != -100`.
    labels = torch.stack([
        right_pad(torch.tensor(sample[1], dtype=torch.long), -100)
        for sample in kept
    ])

    inputs = {
        'input_ids': input_ids.to(device),
        'attention_mask': attention_mask.to(device),
    }
    return inputs, labels.to(device), discarded



# Model class
class PunctuationModel(nn.Module):
    """mBART backbone followed by a linear layer applied at every position."""

    def __init__(self, model_name, num_labels):
        """Load the backbone and size the classifier from its config.

        Args:
            model_name: Checkpoint name or path handed to
                ``MBartModel.from_pretrained``.
            num_labels: Output size of the linear classifier.
        """
        super().__init__()
        self.mbart = MBartModel.from_pretrained(model_name)
        hidden_size = self.mbart.config.d_model
        self.classifier = nn.Linear(hidden_size, num_labels)

    def forward(self, input_ids, attention_mask=None):
        """Return the classifier output computed from the backbone's ``last_hidden_state``."""
        backbone_out = self.mbart(input_ids=input_ids, attention_mask=attention_mask)
        return self.classifier(backbone_out.last_hidden_state)


# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)

# One dataset per split, built from the token/label lists loaded above.
train_dataset, dev_dataset, test_dataset = (
    TokenClassificationDataset(split_tokens, split_labels)
    for split_tokens, split_labels in (
        (train_tokens, train_labels),
        (dev_tokens, dev_labels),
        (test_tokens, test_labels),
    )
)

# All loaders share batch size and collate function; only training is shuffled.
loader_options = {'batch_size': 8, 'collate_fn': collate_fn}
train_loader = DataLoader(train_dataset, shuffle=True, **loader_options)
dev_loader = DataLoader(dev_dataset, shuffle=False, **loader_options)
test_loader = DataLoader(test_dataset, shuffle=False, **loader_options)

# Model, optimiser and loss used by the training/evaluation functions below.
model = PunctuationModel('facebook/mbart-large-50', num_labels=6)
model = model.to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-5)
loss_fn = nn.CrossEntropyLoss()


def evaluate(dataloader):
    """Score the module-level ``model`` on ``dataloader`` without gradient tracking.

    Args:
        dataloader: Iterable yielding ``(inputs, labels, discarded)`` batches
            in the shape returned by ``collate_fn``; ``inputs`` is ``None``
            when every item of a batch was discarded.

    Returns:
        ``(avg_loss, acc, total_discarded)``: the mean loss over the batches
        that were actually scored (``nan`` if none were), the accuracy over
        positions whose label is not ``-100``, and the total number of
        discarded items across all batches.
    """
    model.eval()
    all_preds = []
    all_labels = []
    total_discarded = 0
    running_loss = 0
    scored_batches = 0

    with torch.no_grad():
        for inputs, labels, discarded in dataloader:
            # Count drops for every batch, not only fully discarded ones, so
            # the figure matches what train_one_epoch reports.
            total_discarded += discarded
            if inputs is None:
                continue

            logits = model(inputs['input_ids'], attention_mask=inputs['attention_mask'])

            # Flatten to (positions, classes) / (positions,) for the loss.
            logits = logits.view(-1, logits.shape[-1])
            labels = labels.view(-1)

            running_loss += loss_fn(logits, labels).item()
            scored_batches += 1

            preds = torch.argmax(logits, dim=1)
            # Only positions with a real label take part in the accuracy.
            mask = labels != -100
            all_preds.extend(preds[mask].tolist())
            all_labels.extend(labels[mask].tolist())

    # Average over scored batches only; fully skipped batches add no loss.
    avg_loss = running_loss / scored_batches if scored_batches else float('nan')
    acc = accuracy_score(all_labels, all_preds)
    return avg_loss, acc, total_discarded



def train_one_epoch(dataloader):
    """Run one optimisation pass of the module-level ``model`` over ``dataloader``.

    Args:
        dataloader: Iterable yielding ``(inputs, labels, discarded)`` batches
            in the shape returned by ``collate_fn``; ``inputs`` is ``None``
            when every item of a batch was discarded.

    Returns:
        ``(avg_loss, acc, total_discarded)``: the mean loss over the batches
        that were actually trained on (``nan`` if none were), the accuracy
        over positions whose label is not ``-100``, and the total number of
        discarded items across all batches.
    """
    model.train()
    running_loss = 0
    trained_batches = 0
    all_preds = []
    all_labels = []
    total_discarded = 0

    for inputs, labels, discarded in dataloader:
        total_discarded += discarded
        if inputs is None:
            # Every item of this batch was filtered out by collate_fn.
            continue

        optimizer.zero_grad()
        logits = model(inputs['input_ids'], attention_mask=inputs['attention_mask'])

        # Flatten to (positions, classes) / (positions,) for the loss.
        logits = logits.view(-1, logits.shape[-1])
        labels = labels.view(-1)

        loss = loss_fn(logits, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        trained_batches += 1

        preds = torch.argmax(logits, dim=1)
        # Only positions with a real label take part in the accuracy.
        mask = labels != -100
        all_preds.extend(preds[mask].tolist())
        all_labels.extend(labels[mask].tolist())

    # Average over trained batches only; skipped batches would otherwise
    # deflate the reported loss (and an empty loader would divide by zero).
    avg_loss = running_loss / trained_batches if trained_batches else float('nan')
    acc = accuracy_score(all_labels, all_preds)
    return avg_loss, acc, total_discarded



# Early stopping training loop
def train_for_epochs(train_loader, val_loader, epochs=10, patience=2):
    """Train with early stopping on validation accuracy.

    After each epoch the module-level ``model`` is evaluated on
    ``val_loader``. Whenever validation accuracy improves, the model's state
    dict is saved to ``best_mbart_punctuation_model.pt``. Training stops early
    once ``patience`` consecutive epochs bring no improvement.

    Args:
        train_loader: Loader passed to ``train_one_epoch``.
        val_loader: Loader passed to ``evaluate``.
        epochs: Maximum number of epochs; nothing is trained when ``<= 0``.
        patience: Number of consecutive non-improving epochs tolerated.

    Returns:
        None. Progress and final accuracies are printed.
    """
    # Start below any possible accuracy so the first epoch always writes a checkpoint.
    best_val_acc = float('-inf')
    patience_counter = 0
    train_acc = val_acc = None

    for epoch in range(epochs):
        print(f"\nEpoch {epoch + 1}/{epochs}")
        train_loss, train_acc, train_discarded = train_one_epoch(train_loader)
        val_loss, val_acc, val_discarded = evaluate(val_loader)

        print(f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f} | Discarded: {train_discarded}")
        print(f"Val   Loss: {val_loss:.4f} | Val   Acc: {val_acc:.4f} | Discarded: {val_discarded}")
        print('-' * 60)

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            patience_counter = 0
            torch.save(model.state_dict(), 'best_mbart_punctuation_model.pt')
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print("Early stopping triggered.")
                break

    if train_acc is None:
        # The loop body never ran, so there are no accuracies to report.
        print("No epochs were run.")
        return

    print(f"Final Train Accuracy: {train_acc:.4f} | Final Dev Accuracy: {val_acc:.4f}")


#train_for_epochs(train_loader, dev_loader, epochs=10, patience=2)

# Load best model and evaluate on test
# model.load_state_dict(torch.load('best_mbart_punctuation_model.pt'))
# _, test_acc, _ = evaluate(test_loader)  # evaluate returns (avg_loss, acc, total_discarded)
# print(f"Test Accuracy: {test_acc:.4f}")


2025-05-04 11:51:34.722054: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1746359495.025176      31 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1746359495.100682      31 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


tokenizer_config.json:   0%|          | 0.00/531 [00:00<?, ?B/s]

sentencepiece.bpe.model:   0%|          | 0.00/5.07M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/649 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/1.42k [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/2.44G [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/2.44G [00:00<?, ?B/s]

In [3]:
import gdown
# Download the model checkpoint from Google Drive; the returned local path is the cell output.
model_id = '1uU3tcfKQGomyldf5KmMczFtpzz4IYXb3'
checkpoint_url = 'https://drive.google.com/uc?export=download&id=' + model_id
gdown.download(checkpoint_url, 'best_test_punctuation_model.pt', quiet=False)

Downloading...
From (original): https://drive.google.com/uc?export=download&id=1uU3tcfKQGomyldf5KmMczFtpzz4IYXb3
From (redirected): https://drive.google.com/uc?export=download&id=1uU3tcfKQGomyldf5KmMczFtpzz4IYXb3&confirm=t&uuid=ea9a0b54-d0a4-4756-9c03-d2368a3c2a4b
To: /kaggle/working/best_test_punctuation_model.pt
100%|██████████| 2.44G/2.44G [00:22<00:00, 108MB/s] 


'best_test_punctuation_model.pt'

In [4]:
# Restore the downloaded weights and score them on the test split.
# map_location remaps the stored tensors onto the active device, so loading
# also works when the checkpoint was written from a different device
# (for example a GPU-saved file on a CPU-only runtime).
state_dict = torch.load('best_test_punctuation_model.pt', map_location=device, weights_only=True)
model.load_state_dict(state_dict)
avg_loss, acc, total_discarded = evaluate(test_loader)
print(f"Test Accuracy: {acc:.4f}")

Test Accuracy: 0.8942


In [5]:
import gdown
# Download the dev and test CSV files from Google Drive.
csv_url_template = 'https://drive.google.com/uc?export=download&id={}'

file_id = '1FR-MyJRxZH62grW7DihFqbgA8HG0Wnne'
gdown.download(csv_url_template.format(file_id), 'dev_punctuation_data.csv', quiet=False)

file_id = '1e0JBRmo98zKK8Lbfof9QFZ7IlQiYtDC2'
# Kept as the last expression so the cell still displays the returned path.
gdown.download(csv_url_template.format(file_id), 'test_punctuation_data.csv', quiet=False)



Downloading...
From: https://drive.google.com/uc?export=download&id=1FR-MyJRxZH62grW7DihFqbgA8HG0Wnne
To: /kaggle/working/dev_punctuation_data.csv
100%|██████████| 227k/227k [00:00<00:00, 84.4MB/s]
Downloading...
From: https://drive.google.com/uc?export=download&id=1e0JBRmo98zKK8Lbfof9QFZ7IlQiYtDC2
To: /kaggle/working/test_punctuation_data.csv
100%|██████████| 356k/356k [00:00<00:00, 115MB/s]


'test_punctuation_data.csv'

In [7]:
import pandas as pd
# Load the dev CSV and preview its first rows as the cell output.
dev_csv_path = 'dev_punctuation_data.csv'
df = pd.read_csv(dev_csv_path)
df.head()

  has_large_values = (abs_vals > 1e6).any()
  has_small_values = ((abs_vals < 10 ** (-self.digits)) & (abs_vals > 0)).any()
  has_small_values = ((abs_vals < 10 ** (-self.digits)) & (abs_vals > 0)).any()


Unnamed: 0,id,filename,actual_en,asr_decoded,actual_te,baseline_mt_translated
0,1599,9909934339673808373.wav,"The original population hasn't changed at all,...",The original population hasn't changed at all....,"అసలు జనాభా ఏ మాత్రం మారలేదు, వారికి మునుపటి లా...",
1,1608,4582604850545677686.wav,"I lost my sister and her friend, and on my way...","I lost my sister and her friend, and on my way...",నేను నా సోదరిని ఇంకా ఆమె స్నేహితుడిని కోల్పోయా...,
2,1533,13033938513680724611.wav,"Many common formats (APS family of formats, fo...","Many common formats, APS family of formats for...",అనేక సాధారణ ఫార్మాట్‌లు (ఉదాహరణకు ఫార్మాట్‌ల ...,
3,1585,369662424302423610.wav,"Over time, as the new population begins to ada...","Over time, as the population begins to adapt t...","కాలక్రమేణా, కొత్త జనాభా వారి కొత్త వాతావరణానిక...",
4,1594,11667832843141994163.wav,"Some animals, such as elephants and giraffes, ...","Some animals, such as elephants and giraffes, ...","కొన్ని జంతువులు, ఏనుగులు మరియు జిరాఫీలు వంటి జ...",


In [9]:
import os
import torch
import pandas as pd
from tqdm import tqdm
import json  # Still needed if using tokenizer from transformers
# Ensure you have these already defined:
# - tokenizer
# - model
# - device

# Path to your CSV file
input_csv = 'dev_punctuation_data.csv'
output_dir = 'dev_asr_text_punct_features/'  # Folder to save features

os.makedirs(output_dir, exist_ok=True)

# Load CSV (malformed rows are skipped rather than raising)
df = pd.read_csv(input_csv, encoding='utf-8', on_bad_lines='skip')

# Switch to evaluation mode explicitly so the extracted features do not depend
# on whether an earlier cell (e.g. evaluate) happened to call model.eval().
model.eval()

for _, row in tqdm(df.iterrows(), total=len(df), desc="Extracting features"):
    uid = str(row['filename'])
    # NOTE(review): a missing asr_decoded value would become the literal text
    # 'nan' through str(); confirm the CSV has no empty transcripts.
    original_text = str(row['asr_decoded'])

    # Tokenize one transcript at a time (batch dimension of 1)
    encoded = tokenizer(original_text, return_tensors='pt', padding=True, truncation=True)
    input_ids = encoded['input_ids'].to(device)
    attention_mask = encoded['attention_mask'].to(device)

    with torch.no_grad():
        outputs = model.mbart(input_ids=input_ids, attention_mask=attention_mask)
        last_hidden_state = outputs.last_hidden_state.squeeze(0).cpu()  # [seq_len, hidden_size]

    # Save to a file named by the row's filename value
    torch.save(last_hidden_state, os.path.join(output_dir, f"{uid}.pt"))


Extracting features: 100%|██████████| 365/365 [00:07<00:00, 50.92it/s]


In [10]:
import os
import torch
import pandas as pd
from tqdm import tqdm
import json  # Still needed if using tokenizer from transformers
# Ensure you have these already defined:
# - tokenizer
# - model
# - device

# Path to your CSV file
input_csv = 'test_punctuation_data.csv'
output_dir = 'test_asr_text_punct_features/'  # Folder to save features

os.makedirs(output_dir, exist_ok=True)

# Load CSV (malformed rows are skipped rather than raising)
df = pd.read_csv(input_csv, encoding='utf-8', on_bad_lines='skip')

# Switch to evaluation mode explicitly so the extracted features do not depend
# on whether an earlier cell (e.g. evaluate) happened to call model.eval().
model.eval()

for _, row in tqdm(df.iterrows(), total=len(df), desc="Extracting features"):
    uid = str(row['filename'])
    # NOTE(review): a missing asr_decoded value would become the literal text
    # 'nan' through str(); confirm the CSV has no empty transcripts.
    original_text = str(row['asr_decoded'])

    # Tokenize one transcript at a time (batch dimension of 1)
    encoded = tokenizer(original_text, return_tensors='pt', padding=True, truncation=True)
    input_ids = encoded['input_ids'].to(device)
    attention_mask = encoded['attention_mask'].to(device)

    with torch.no_grad():
        outputs = model.mbart(input_ids=input_ids, attention_mask=attention_mask)
        last_hidden_state = outputs.last_hidden_state.squeeze(0).cpu()  # [seq_len, hidden_size]

    # Save to a file named by the row's filename value
    torch.save(last_hidden_state, os.path.join(output_dir, f"{uid}.pt"))


Extracting features: 100%|██████████| 554/554 [00:10<00:00, 52.30it/s]


In [14]:
import shutil
from IPython.display import FileLink

# Zip the dev feature folder, then render a download link as the cell output.
archive_base, feature_dir = "dev_asr_text_punct_features", "dev_asr_text_punct_features/"
shutil.make_archive(archive_base, 'zip', feature_dir)
FileLink('dev_asr_text_punct_features.zip')


In [15]:
import shutil
from IPython.display import FileLink

# Zip the test feature folder, then render a download link as the cell output.
archive_base, feature_dir = "test_asr_text_punct_features", "test_asr_text_punct_features/"
shutil.make_archive(archive_base, 'zip', feature_dir)
FileLink('test_asr_text_punct_features.zip')
