In [1]:
import gdown

csv_file_id = '1-DfbuRqP5xe9eNMjobEDH_cZR1FlwyEw'
gdown.download(f'https://drive.google.com/uc?export=download&id={csv_file_id}', 'MT_data.csv', quiet=False)
tsv_file_id = '1pBYaqKbkJt66bJlBJGKvTaWVI6Q0yNCC'
gdown.download(f'https://drive.google.com/uc?export=download&id={tsv_file_id}', 'eng_data.tsv', quiet=False)
zip_file_id = '1aNBV7Jgpgm1d21IGnbsIB0E35kZbb0Sc'
gdown.download(f'https://drive.google.com/uc?export=download&id={zip_file_id}', 'punctuation_data.zip', quiet=False)

csv_file_id = '1-I-VFAdiAvFF-7x_Y_0_TwDf0v2kl80l'
gdown.download(f'https://drive.google.com/uc?export=download&id={csv_file_id}', 'MT_dev.csv', quiet=False)
tsv_file_id = '1M6uGyGSOEW7wKirzi7dlizlRaN9wStwD'
gdown.download(f'https://drive.google.com/uc?export=download&id={tsv_file_id}', 'dev.tsv', quiet=False)
zip_file_id = '1w4FR5I3GUp3cq1I-LIGREWRufLiwrAVU'
gdown.download(f'https://drive.google.com/uc?export=download&id={zip_file_id}', 'punctuation_dev.zip', quiet=False)

csv_file_id = '1-GVAKGJs92uzo8bF1d-q9J_apyqP1ofG'
gdown.download(f'https://drive.google.com/uc?export=download&id={csv_file_id}', 'MT_test.csv', quiet=False)
tsv_file_id = '1m7JIcj_-SVYASbcbBiT4LoPdgZdqGN4i'
gdown.download(f'https://drive.google.com/uc?export=download&id={tsv_file_id}', 'test.tsv', quiet=False)
zip_file_id = '1Onbil_LmG8MSFPwGPRe5-8iuqFMrRBv5'
gdown.download(f'https://drive.google.com/uc?export=download&id={zip_file_id}', 'punctuation_test.zip', quiet=False)

zip_file_id = '10MA1p4EGfU2zIGIG9j3ZesA73oTVR61-'
gdown.download(f'https://drive.google.com/uc?export=download&id={zip_file_id}', 'punctuation_dev_asr.zip', quiet=False)

zip_file_id = '1NoxnboKl2H_bXn_JeaRIlsDBv_CotfLD'
gdown.download(f'https://drive.google.com/uc?export=download&id={zip_file_id}', 'punctuation_test_asr.zip', quiet=False)

Downloading...
From: https://drive.google.com/uc?export=download&id=1-DfbuRqP5xe9eNMjobEDH_cZR1FlwyEw
To: /kaggle/working/MT_data.csv
100%|██████████| 599k/599k [00:00<00:00, 99.5MB/s]
Downloading...
From: https://drive.google.com/uc?export=download&id=1pBYaqKbkJt66bJlBJGKvTaWVI6Q0yNCC
To: /kaggle/working/eng_data.tsv
100%|██████████| 1.41M/1.41M [00:00<00:00, 100MB/s]
Downloading...
From (original): https://drive.google.com/uc?export=download&id=1aNBV7Jgpgm1d21IGnbsIB0E35kZbb0Sc
From (redirected): https://drive.google.com/uc?export=download&id=1aNBV7Jgpgm1d21IGnbsIB0E35kZbb0Sc&confirm=t&uuid=4313fac5-9464-496d-901f-0dd503f2b2f6
To: /kaggle/working/punctuation_data.zip
100%|██████████| 333M/333M [00:03<00:00, 101MB/s]  
Downloading...
From: https://drive.google.com/uc?export=download&id=1-I-VFAdiAvFF-7x_Y_0_TwDf0v2kl80l
To: /kaggle/working/MT_dev.csv
100%|██████████| 66.2k/66.2k [00:00<00:00, 44.2MB/s]
Downloading...
From: https://drive.google.com/uc?export=download&id=1M6uGyGSOEW7wKir

'punctuation_test_asr.zip'

In [2]:
import os
import zipfile
import torch
import pandas as pd
from transformers import MBartTokenizer

class EnglishTeluguPunctDataset:
    def __init__(self, zip_path, tsv_path, csv_path, tokenizer, max_length=128):
        self.zip_path = zip_path
        self.tsv_path = tsv_path
        self.csv_path = csv_path
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.skipped = 0

        # Load MT_data.csv (contains both English and Telugu)
        mt_data = pd.read_csv(csv_path)
        self.id_to_english = mt_data.set_index("id")["english"].to_dict()
        self.id_to_telugu = mt_data.set_index("id")["telugu"].to_dict()

        # Load eng_data.tsv (only for filename to ID mapping)
        eng_tsv = pd.read_csv(tsv_path, sep="\t", header=None, dtype=str)
        eng_tsv = eng_tsv[eng_tsv.columns[:2]]
        eng_tsv.columns = ["id", "filename"]
        eng_tsv["id"] = eng_tsv["id"].astype(int)
        self.filename_to_id = eng_tsv.set_index("filename")["id"].to_dict()

        # Read .pt filenames from ZIP
        with zipfile.ZipFile(zip_path, 'r') as zipf:
            self.pt_filenames = [f for f in zipf.namelist() if f.endswith('.pt')]

        # Match each .pt file to a dataset entry
        self.valid_entries = []
        for pt_file in self.pt_filenames:
            base = os.path.basename(pt_file)
            filename_without_extension = base.replace(".pt", "")

            if filename_without_extension in self.filename_to_id:
                id_ = self.filename_to_id[filename_without_extension]
                if id_ in self.id_to_english and id_ in self.id_to_telugu:
                    self.valid_entries.append((pt_file, id_))
                else:
                    self.skipped += 1
            else:
                self.skipped += 1

    def __len__(self):
        return len(self.valid_entries)

    def __getitem__(self, idx):
        pt_file, id_ = self.valid_entries[idx]
    
        # Load punctuation features from .pt file
        with zipfile.ZipFile(self.zip_path, 'r') as zipf:
            with zipf.open(pt_file) as f:
                punct_feats = torch.load(f, weights_only=True)
    
        # Get English and Telugu text
        src_text = self.id_to_english[id_]
        tgt_text = self.id_to_telugu[id_]
    
        # Tokenize without max_length padding
        src_enc = self.tokenizer(src_text, return_tensors="pt", padding=False, truncation=True)
        tgt_enc = self.tokenizer(tgt_text, return_tensors="pt", padding=False, truncation=True)
    
        # Set language codes as first token
        src_lang = "en_XX"
        tgt_lang = "te_IN"
        src_enc.input_ids[0][0] = self.tokenizer.lang_code_to_id[src_lang]
        tgt_enc.input_ids[0][0] = self.tokenizer.lang_code_to_id[tgt_lang]
    
        return {
            "input_ids": src_enc.input_ids.squeeze(0),
            "attention_mask": src_enc.attention_mask.squeeze(0),
            "labels": tgt_enc.input_ids.squeeze(0),
            "punct_feats": punct_feats,
            "src": src_text,
            "tgt": tgt_text,
            "filename": os.path.basename(pt_file).replace(".pt", ""),
            "id": id_
        }

    def get_skipped_count(self):
        return self.skipped


In [3]:
import torch
import torch.nn as nn
from transformers import MBartForConditionalGeneration, MBart50TokenizerFast
from torch.optim import AdamW
from tqdm import tqdm
from transformers.modeling_outputs import BaseModelOutput
from torch.utils.data import DataLoader



model_name = "facebook/mbart-large-50-many-to-many-mmt"
# Load tokenizer and model
# tokenizer = MBart50TokenizerFast.from_pretrained(model_name)
model = MBartForConditionalGeneration.from_pretrained(model_name)

# Initialize the tokenizer
tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")

# Initialize the dataset
train_dataset = EnglishTeluguPunctDataset(zip_path="punctuation_data.zip",
                                    tsv_path="eng_data.tsv",
                                    csv_path="MT_data.csv",
                                    tokenizer=tokenizer)
# Create DataLoader
train_dataloader = DataLoader(train_dataset, batch_size=1, shuffle=True)

# Initialize the dataset
dev_dataset = EnglishTeluguPunctDataset(zip_path="punctuation_dev.zip",
                                    tsv_path="dev.tsv",
                                    csv_path="MT_dev.csv",
                                    tokenizer=tokenizer)

# Create DataLoader
dev_dataloader = DataLoader(dev_dataset, batch_size=1, shuffle=True)



# Initialize the dataset
test_dataset = EnglishTeluguPunctDataset(zip_path="punctuation_test.zip",
                                    tsv_path="test.tsv",
                                    csv_path="MT_test.csv",
                                    tokenizer=tokenizer)

# Create DataLoader
test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=True)

# ASR


dev_asr_dataset = EnglishTeluguPunctDataset(zip_path="punctuation_dev_asr.zip",
                                    tsv_path="dev.tsv",
                                    csv_path="MT_dev.csv",
                                    tokenizer=tokenizer)

# Create DataLoader
dev_asr_dataloader = DataLoader(dev_dataset, batch_size=1, shuffle=True)



# Initialize the dataset
test_asr_dataset = EnglishTeluguPunctDataset(zip_path="punctuation_test_asr.zip",
                                    tsv_path="test.tsv",
                                    csv_path="MT_test.csv",
                                    tokenizer=tokenizer)

# Create DataLoader
test_asr_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=True)


2025-05-05 11:33:23.132192: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1746444803.154487     187 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1746444803.161200     187 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [None]:
SRC_LANG = "en_XX"
TGT_LANG = "te_IN"

tokenizer.src_lang = SRC_LANG
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Set model to training mode
model.train()

# Create the DataLoader (assuming your dataset is already created)
# dataset = EnglishTeluguPunctDataset(...)
# Optimizer
optimizer = AdamW(model.parameters(), lr=5e-5)

import torch.nn as nn

from transformers.modeling_outputs import BaseModelOutput
skip=0

import torch.nn as nn

# Assume punct_feats shape: (batch_size, seq_len, feat_dim)
feat_dim = train_dataset[0]["punct_feats"].shape[1]
hidden_dim = model.model.encoder.config.d_model  # mBART hidden size (e.g., 1024)

# Punctuation feature encoder
punct_feat_encoder = nn.Linear(feat_dim, hidden_dim).to(device)

# Add to optimizer
optimizer = AdamW(list(model.parameters()) + list(punct_feat_encoder.parameters()), lr=5e-5)


# for epoch in range(3):  # Example: 3 epochs
#     model.train()
#     for batch in tqdm(train_dataloader, desc=f"Epoch {epoch+1}"):
#         optimizer.zero_grad()
        
#         input_ids = batch['input_ids'].to(device)
#         attention_mask = batch['attention_mask'].to(device)
#         labels = batch['labels'].to(device)
#         punct_feats = batch['punct_feats'].to(device)  # (batch, seq_len, feat_dim)

#         # Get encoder output from mBART
#         encoder_outputs = model.model.encoder(
#             input_ids=input_ids,
#             attention_mask=attention_mask,
#             return_dict=True
#         )

#         try:
#             # Add punctuation features to encoder hidden states
#             combined_hidden_state = encoder_outputs.last_hidden_state + punct_feats * 0.1
#         except:
#             skip+=1
#             continue

#         # Wrap into BaseModelOutput
#         encoder_outputs = BaseModelOutput(last_hidden_state=combined_hidden_state)

#         decoder_input_ids = model.prepare_decoder_input_ids_from_labels(labels)

#         # outputs = model(
#         #     input_ids=None,
#         #     attention_mask=None,
#         #     encoder_outputs=encoder_outputs,
#         #     decoder_input_ids=decoder_input_ids,
#         #     labels=labels
#         # )

#         # Forward pass through the full model
#         outputs = model(
#             input_ids=None,
#             attention_mask=None,
#             decoder_input_ids=decoder_input_ids,  # Use left-shifted target if needed
#             encoder_outputs=encoder_outputs,
#             labels=labels,
#             return_dict=True
#         )

#         # Backpropagation
#         loss = outputs.loss
#         loss.backward()
#         optimizer.step()

#     print(f"Epoch {epoch + 1}, Loss: {loss.item():.4f}")

for epoch in range(3):
    model.train()
    punct_feat_encoder.train()

    for batch in tqdm(train_dataloader, desc=f"Epoch {epoch+1}"):
        optimizer.zero_grad()

        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)
        punct_feats = batch['punct_feats'].to(device)  # (batch, seq_len, feat_dim)

        # Forward encoder
        encoder_outputs = model.model.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True
        )

        # Pass punctuation features through its encoder
        try:
            encoded_punct_feats = punct_feat_encoder(punct_feats)  # (batch, seq_len, hidden_dim)
            combined_hidden_state = encoder_outputs.last_hidden_state + encoded_punct_feats
        except Exception as e:
            skip += 1
            continue

        encoder_outputs = BaseModelOutput(last_hidden_state=combined_hidden_state)
        decoder_input_ids = model.prepare_decoder_input_ids_from_labels(labels)

        outputs = model(
            input_ids=None,
            attention_mask=None,
            decoder_input_ids=decoder_input_ids,
            encoder_outputs=encoder_outputs,
            labels=labels,
            return_dict=True
        )

        loss = outputs.loss
        loss.backward()
        optimizer.step()

    print(f"Epoch {epoch + 1}, Loss: {loss.item():.4f}")


In [None]:
# Save model and tokenizer to a directory
save_path = "mbart_en_te_text_punc_only"
model.save_pretrained(save_path)
tokenizer.save_pretrained(save_path)


In [None]:
import shutil

# Directory to zip
save_path = "mbart_en_te_text_punc_only"

# Create a zip file (it will be named 'mbart_en_te_text_punc_only.zip')
shutil.make_archive(save_path, 'zip', save_path)


In [None]:
torch.save(model.state_dict(), "mbart_en_te_text_punc_only.pt")


In [None]:
print(translations[:10])

In [4]:
import gdown
model_id = '13djMIYgy7etmFcD7A31x4ucGxgQ-NACc'
gdown.download(f'https://drive.google.com/uc?export=download&id={model_id}', 'mbart_en_te_text_punc_with_encoder_only1.pt', quiet=False)

Downloading...
From (original): https://drive.google.com/uc?export=download&id=13djMIYgy7etmFcD7A31x4ucGxgQ-NACc
From (redirected): https://drive.google.com/uc?export=download&id=13djMIYgy7etmFcD7A31x4ucGxgQ-NACc&confirm=t&uuid=9ab43ff8-66e9-4555-a34b-2a33640db8b6
To: /kaggle/working/mbart_en_te_text_punc_with_encoder_only1.pt
100%|██████████| 2.44G/2.44G [00:13<00:00, 182MB/s] 


'mbart_en_te_text_punc_with_encoder_only1.pt'

In [5]:
state_dict = torch.load('mbart_en_te_text_punc_with_encoder_only1.pt', weights_only=True)
model.load_state_dict(state_dict)


<All keys matched successfully>

In [6]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)


MBartForConditionalGeneration(
  (model): MBartModel(
    (shared): MBartScaledWordEmbedding(250054, 1024, padding_idx=1)
    (encoder): MBartEncoder(
      (embed_tokens): MBartScaledWordEmbedding(250054, 1024, padding_idx=1)
      (embed_positions): MBartLearnedPositionalEmbedding(1026, 1024)
      (layers): ModuleList(
        (0-11): 12 x MBartEncoderLayer(
          (self_attn): MBartSdpaAttention(
            (k_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (v_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (q_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (out_proj): Linear(in_features=1024, out_features=1024, bias=True)
          )
          (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
          (activation_fn): ReLU()
          (fc1): Linear(in_features=1024, out_features=4096, bias=True)
          (fc2): Linear(in_features=4096, out_features=1024, bias=True)
    

In [None]:
from transformers import MBartForConditionalGeneration, MBart50TokenizerFast
import torch
import pandas as pd
import csv

# Assuming model and tokenizer are already loaded
# model = MBartForConditionalGeneration.from_pretrained(model_path)
# tokenizer = MBart50TokenizerFast.from_pretrained(model_path)

model.eval()

# Define the source and target languages
SRC_LANG = "en_XX"
TGT_LANG = "te_IN"

# Set the tokenizer source language
tokenizer.src_lang = SRC_LANG

# Placeholder for storing translations
translations = []

# Prepare CSV file (write header once)
output_csv = "dev_translations.csv"
with open(output_csv, mode='w', newline='', encoding='utf-8') as f:
    writer = csv.DictWriter(f, fieldnames=["filename", "id", "src", "tgt", "translated"])
    writer.writeheader()

    skip2=0
    # Iterate over the dev_loader
    for batch in tqdm(dev_dataloader, desc="Processing test data", unit="batch"):
        # Assuming 'input_ids' and 'attention_mask' are present in the batch
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        punct_feats = batch['punct_feats'].to(device)
        
        # Unwrap values from lists or tensors
        filename = batch['filename'][0] if isinstance(batch['filename'], list) else batch['filename']
        fileid = batch['id'].item() if isinstance(batch['id'], torch.Tensor) else batch['id']
        src_text = batch['src'][0] if isinstance(batch['src'], list) else batch['src']
        tgt_text = batch['tgt'][0] if isinstance(batch['tgt'], list) else batch['tgt']

        
        # Set the language token for the source
        input_ids[0][0] = tokenizer.lang_code_to_id[SRC_LANG]
        
        # Run the encoder manually
        with torch.no_grad():
            encoder_outputs = model.model.encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
                return_dict=True
            )
    
        try:
            # Add punctuation features to encoder hidden states
            combined_hidden_state = encoder_outputs.last_hidden_state
        except:
            skip2+=1
            # print('skipping :(')
            continue
    
        # Wrap into BaseModelOutput
        encoder_outputs = BaseModelOutput(last_hidden_state=combined_hidden_state)
    
        # Generate translation from decoder with forced beginning token for Telugu
        generated_tokens = model.generate(
            input_ids=None,
            encoder_outputs=encoder_outputs,
            attention_mask=None,
            forced_bos_token_id=tokenizer.lang_code_to_id[TGT_LANG],
            max_length=50
        )
        
        # Decode the generated tokens into text
        translated_text = tokenizer.decode(generated_tokens[0], skip_special_tokens=True)
        translations.append(translated_text)
    
        writer.writerow({
                    "filename": filename,
                    "id": fileid,
                    "src": src_text,
                    "tgt": tgt_text,
                    "translated": translated_text
                })
    
    
    # Print out translations for the entire dev set
    for translated_text in translations:
        print(translated_text)
        break
    
    print(skip2)

In [None]:
df= pd.read_csv("dev_translations.csv")
print(df.head())

In [None]:
from transformers import MBartForConditionalGeneration, MBart50TokenizerFast
import torch
import pandas as pd
import csv

# Assuming model and tokenizer are already loaded
# model = MBartForConditionalGeneration.from_pretrained(model_path)
# tokenizer = MBart50TokenizerFast.from_pretrained(model_path)

model.eval()

# Define the source and target languages
SRC_LANG = "en_XX"
TGT_LANG = "te_IN"

# Set the tokenizer source language
tokenizer.src_lang = SRC_LANG

# Placeholder for storing translations
translations = []

# Prepare CSV file (write header once)
output_csv = "text_test_translations_2.csv"
with open(output_csv, mode='w', newline='', encoding='utf-8') as f:
    writer = csv.DictWriter(f, fieldnames=["filename", "id", "src", "tgt", "translated"])
    writer.writeheader()

    skip2=0
    # Iterate over the dev_loader
    for batch in tqdm(test_dataloader, desc="Processing test data", unit="batch"):
        # Assuming 'input_ids' and 'attention_mask' are present in the batch
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        punct_feats = batch['punct_feats'].to(device)
        
        # Unwrap values from lists or tensors
        filename = batch['filename'][0] if isinstance(batch['filename'], list) else batch['filename']
        fileid = batch['id'].item() if isinstance(batch['id'], torch.Tensor) else batch['id']
        src_text = batch['src'][0] if isinstance(batch['src'], list) else batch['src']
        tgt_text = batch['tgt'][0] if isinstance(batch['tgt'], list) else batch['tgt']

        
        # Set the language token for the source
        input_ids[0][0] = tokenizer.lang_code_to_id[SRC_LANG]
        
        # Run the encoder manually
        with torch.no_grad():
            encoder_outputs = model.model.encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
                return_dict=True
            )
    
        try:
            # Add punctuation features to encoder hidden states
            combined_hidden_state = encoder_outputs.last_hidden_state
        except:
            skip2+=1
            # print('skipping :(')
            continue
    
        # Wrap into BaseModelOutput
        encoder_outputs = BaseModelOutput(last_hidden_state=combined_hidden_state)
    
        # Generate translation from decoder with forced beginning token for Telugu
        generated_tokens = model.generate(
            input_ids=None,
            encoder_outputs=encoder_outputs,
            attention_mask=None,
            forced_bos_token_id=tokenizer.lang_code_to_id[TGT_LANG],
            max_length=50
        )
        
        # Decode the generated tokens into text
        translated_text = tokenizer.decode(generated_tokens[0], skip_special_tokens=True)
        translations.append(translated_text)
    
        writer.writerow({
                    "filename": filename,
                    "id": fileid,
                    "src": src_text,
                    "tgt": tgt_text,
                    "translated": translated_text
                })
    
    
    # Print out translations for the entire dev set
    for translated_text in translations:
        print(translated_text)
        break
    
    print(skip2)

In [None]:
df= pd.read_csv("test_translations.csv")
print(df.head())

In [7]:
csv_file_id = '1FR-MyJRxZH62grW7DihFqbgA8HG0Wnne'
gdown.download(f'https://drive.google.com/uc?export=download&id={csv_file_id}', 'asr_decoded_dev.csv', quiet=False)

#https://drive.google.com/file/d/1e0JBRmo98zKK8Lbfof9QFZ7IlQiYtDC2/view?usp=drive_link
csv_file_id = '1e0JBRmo98zKK8Lbfof9QFZ7IlQiYtDC2'
gdown.download(f'https://drive.google.com/uc?export=download&id={csv_file_id}', 'asr_decoded_test.csv', quiet=False)

Downloading...
From: https://drive.google.com/uc?export=download&id=1FR-MyJRxZH62grW7DihFqbgA8HG0Wnne
To: /kaggle/working/asr_decoded_dev.csv
100%|██████████| 227k/227k [00:00<00:00, 98.5MB/s]
Downloading...
From: https://drive.google.com/uc?export=download&id=1e0JBRmo98zKK8Lbfof9QFZ7IlQiYtDC2
To: /kaggle/working/asr_decoded_test.csv
100%|██████████| 356k/356k [00:00<00:00, 97.8MB/s]


'asr_decoded_test.csv'

In [8]:
import pandas as pd
import torch

def update_dataset_with_asr(dataset, asr_csv_path, tokenizer, src_lang="en_XX"):
    # Load the ASR CSV
    asr_df = pd.read_csv(asr_csv_path)
    filename_to_asr = dict(zip(asr_df["filename"], asr_df["asr_decoded"]))
    
    # Loop through the dataset and modify the fields
    updated_entries = []
    
    for entry in dataset:
        filename = entry["filename"]  # Get filename from the dataset
        
        if filename in filename_to_asr:
            # Get the corresponding ASR decoded text from CSV
            asr_text = filename_to_asr[filename]
            
            # Tokenize the new ASR text
            enc = tokenizer(asr_text, return_tensors="pt", padding=False, truncation=True)
            enc.input_ids[0][0] = tokenizer.lang_code_to_id[src_lang]  # Add language code to first token
            
            # Update the dataset entry
            entry["input_ids"] = enc.input_ids.squeeze(0)
            entry["attention_mask"] = enc.attention_mask.squeeze(0)
            entry["src"] = asr_text
            
            updated_entries.append(entry)
        else:
            # If no match is found, keep the original entry
            updated_entries.append(entry)
    
    return updated_entries


In [9]:
dev_asr_dataset = update_dataset_with_asr(dev_asr_dataset, "asr_decoded_dev.csv", tokenizer)
test_asr_dataset = update_dataset_with_asr(test_asr_dataset, "asr_decoded_test.csv", tokenizer)

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


In [10]:
print(len(dev_asr_dataset))
print(len(test_asr_dataset))

365
554


In [11]:
dev_asr_dataloader = DataLoader(dev_asr_dataset, batch_size=1, shuffle=False)
test_asr_dataloader = DataLoader(test_asr_dataset, batch_size=1, shuffle=False)

In [12]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

MBartForConditionalGeneration(
  (model): MBartModel(
    (shared): MBartScaledWordEmbedding(250054, 1024, padding_idx=1)
    (encoder): MBartEncoder(
      (embed_tokens): MBartScaledWordEmbedding(250054, 1024, padding_idx=1)
      (embed_positions): MBartLearnedPositionalEmbedding(1026, 1024)
      (layers): ModuleList(
        (0-11): 12 x MBartEncoderLayer(
          (self_attn): MBartSdpaAttention(
            (k_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (v_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (q_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (out_proj): Linear(in_features=1024, out_features=1024, bias=True)
          )
          (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
          (activation_fn): ReLU()
          (fc1): Linear(in_features=1024, out_features=4096, bias=True)
          (fc2): Linear(in_features=4096, out_features=1024, bias=True)
    

In [None]:
from transformers import MBartForConditionalGeneration, MBart50TokenizerFast
import torch
import pandas as pd
import csv

# Assuming model and tokenizer are already loaded
# model = MBartForConditionalGeneration.from_pretrained(model_path)
# tokenizer = MBart50TokenizerFast.from_pretrained(model_path)

model.eval()

# Define the source and target languages
SRC_LANG = "en_XX"
TGT_LANG = "te_IN"

# Set the tokenizer source language
tokenizer.src_lang = SRC_LANG

# Placeholder for storing translations
translations = []

# Prepare CSV file (write header once)
output_csv = "text_dev_asr_translations.csv"
with open(output_csv, mode='w', newline='', encoding='utf-8') as f:
    writer = csv.DictWriter(f, fieldnames=["filename", "id", "src", "tgt", "translated"])
    writer.writeheader()

    skip2=0
    # Iterate over the dev_loader
    for batch in tqdm(dev_asr_dataloader, desc="Processing test data", unit="batch"):
        # Assuming 'input_ids' and 'attention_mask' are present in the batch
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        punct_feats = batch['punct_feats'].to(device)
        
        # Unwrap values from lists or tensors
        filename = batch['filename'][0] if isinstance(batch['filename'], list) else batch['filename']
        fileid = batch['id'].item() if isinstance(batch['id'], torch.Tensor) else batch['id']
        src_text = batch['src'][0] if isinstance(batch['src'], list) else batch['src']
        tgt_text = batch['tgt'][0] if isinstance(batch['tgt'], list) else batch['tgt']

        
        # Set the language token for the source
        input_ids[0][0] = tokenizer.lang_code_to_id[SRC_LANG]
        
        # Run the encoder manually
        with torch.no_grad():
            encoder_outputs = model.model.encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
                return_dict=True
            )
    
        try:
            # Add punctuation features to encoder hidden states
            combined_hidden_state = encoder_outputs.last_hidden_state
        except:
            skip2+=1
            # print('skipping :(')
            continue
    
        # Wrap into BaseModelOutput
        encoder_outputs = BaseModelOutput(last_hidden_state=combined_hidden_state)
    
        # Generate translation from decoder with forced beginning token for Telugu
        generated_tokens = model.generate(
            input_ids=None,
            encoder_outputs=encoder_outputs,
            attention_mask=None,
            forced_bos_token_id=tokenizer.lang_code_to_id[TGT_LANG],
            max_length=50
        )
        
        # Decode the generated tokens into text
        translated_text = tokenizer.decode(generated_tokens[0], skip_special_tokens=True)
        #print(translated_text)
        translations.append(translated_text)
    
        writer.writerow({
                    "filename": filename,
                    "id": fileid,
                    "src": src_text,
                    "tgt": tgt_text,
                    "translated": translated_text
                })
    
    
    # Print out translations for the entire dev set
    for translated_text in translations:
        print(translated_text)
        break
    
    print(skip2)

In [None]:
df= pd.read_csv("text_dev_asr_translations.csv")
print(df.head())

In [14]:
from transformers import MBartForConditionalGeneration, MBart50TokenizerFast
import torch
import pandas as pd
import csv

model.eval()

# Define the source and target languages
SRC_LANG = "en_XX"
TGT_LANG = "te_IN"

# Set the tokenizer source language
tokenizer.src_lang = SRC_LANG

# Placeholder for storing translations
translations = []

# Prepare CSV file (write header once)
output_csv = "text_test_asr_translations_3.csv"
with open(output_csv, mode='w', newline='', encoding='utf-8') as f:
    writer = csv.DictWriter(f, fieldnames=["filename", "id", "src", "tgt", "translated"])
    writer.writeheader()

    skip2=0
    # Iterate over the dev_loader
    for batch in tqdm(test_asr_dataloader, desc="Processing test data", unit="batch"):
        # Assuming 'input_ids' and 'attention_mask' are present in the batch
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        punct_feats = batch['punct_feats'].to(device)
        
        # Unwrap values from lists or tensors
        filename = batch['filename'][0] if isinstance(batch['filename'], list) else batch['filename']
        fileid = batch['id'].item() if isinstance(batch['id'], torch.Tensor) else batch['id']
        src_text = batch['src'][0] if isinstance(batch['src'], list) else batch['src']
        tgt_text = batch['tgt'][0] if isinstance(batch['tgt'], list) else batch['tgt']

        
        # Set the language token for the source
        input_ids[0][0] = tokenizer.lang_code_to_id[SRC_LANG]
        
        # Run the encoder manually
        with torch.no_grad():
            encoder_outputs = model.model.encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
                return_dict=True
            )
    
        try:
            # Add punctuation features to encoder hidden states
            combined_hidden_state = encoder_outputs.last_hidden_state
        except:
            skip2+=1
            # print('skipping :(')
            continue
    
        # Wrap into BaseModelOutput
        encoder_outputs = BaseModelOutput(last_hidden_state=combined_hidden_state)
    
        # Generate translation from decoder with forced beginning token for Telugu
        generated_tokens = model.generate(
            input_ids=None,
            encoder_outputs=encoder_outputs,
            attention_mask=None,
            forced_bos_token_id=tokenizer.lang_code_to_id[TGT_LANG],
            max_length=50
        )
        
        # Decode the generated tokens into text
        translated_text = tokenizer.decode(generated_tokens[0], skip_special_tokens=True)
        translations.append(translated_text)
    
        writer.writerow({
                    "filename": filename,
                    "id": fileid,
                    "src": src_text,
                    "tgt": tgt_text,
                    "translated": translated_text
                })
    
    
    # Print out translations for the entire dev set
    for translated_text in translations:
        print(translated_text)
        break
    
    print(skip2)

Processing test data: 100%|██████████| 554/554 [06:38<00:00,  1.39batch/s]

నాగరికత పదం లాటిస్ సైవిస్ కు చెందినది, అంటే నాగరికత, అంటే పౌరునికి చెందినది, అంటే పట్టణం లేదా నగర-రాజ్యకి చెందినది. ఇంకా ఇది ఒక కంపెనీ యొక్క పరిమా
0





In [15]:
df= pd.read_csv("text_test_asr_translations_2.csv")
print(df.head())

                   filename    id  \
0   7266355212113564962.wav  1892   
1   5381110210154713971.wav  1902   
2   1724953769276277810.wav  1944   
3   4705502721902980056.wav  1839   
4  17498257810809617374.wav  1937   

                                                 src  \
0  This is the place the British colonisers took ...   
1  He was initially hospitalised in the James Pag...   
2  These are sometimes-crowded family beaches wit...   
3  Think of the skiing route as of a similar hiki...   
4  However, these plans were rendered obsolete ne...   

                                                 tgt  \
0  బ్రిటిష్ వలసవాదులు తమ సొంత ప్రదేశంగా తీసుకున్న...   
1  అతను ప్రారంభంలో గ్రేట్ యార్మౌత్ లోని జేమ్స్ పే...   
2  ఇవి కొన్నిసార్లు కుటుంబాలతో రద్దీగా ఉండే బీచ్‌...   
3  అదే విధమైన హైకింగ్ రూట్ వలే స్కీయింగ్ రూట్ గుర...   
4  ఏది ఏమయినప్పటికీ, రాత్రికి రాత్రే పథకాన్ని అమల...   

                                          translated  
0  బ్రిటిష్ కాలనీడర్లు తమ సొంతంగా తీసుకున్న స్థల