In [1]:
import gdown

csv_file_id = '1-DfbuRqP5xe9eNMjobEDH_cZR1FlwyEw'
gdown.download(f'https://drive.google.com/uc?export=download&id={csv_file_id}', 'MT_train.csv', quiet=False)
tsv_file_id = '1pBYaqKbkJt66bJlBJGKvTaWVI6Q0yNCC'
gdown.download(f'https://drive.google.com/uc?export=download&id={tsv_file_id}', 'train.tsv', quiet=False)
mat_file_id = '1UQTiPe34L4VAOZj3OFOIl2FnHsf7y0nf'
gdown.download(f'https://drive.google.com/uc?export=download&id={mat_file_id}', 'punctuation_train.mat', quiet=False)


csv_file_id = '1-I-VFAdiAvFF-7x_Y_0_TwDf0v2kl80l'
gdown.download(f'https://drive.google.com/uc?export=download&id={csv_file_id}', 'MT_dev.csv', quiet=False)
tsv_file_id = '1M6uGyGSOEW7wKirzi7dlizlRaN9wStwD'
gdown.download(f'https://drive.google.com/uc?export=download&id={tsv_file_id}', 'dev.tsv', quiet=False)
mat_file_id = '1veQfrTT9aTfxRwJ-Sf70QB-ae-kvca8Y'
gdown.download(f'https://drive.google.com/uc?export=download&id={mat_file_id}', 'punctuation_dev.mat', quiet=False)


csv_file_id = '1-GVAKGJs92uzo8bF1d-q9J_apyqP1ofG'
gdown.download(f'https://drive.google.com/uc?export=download&id={csv_file_id}', 'MT_test.csv', quiet=False)
tsv_file_id = '1m7JIcj_-SVYASbcbBiT4LoPdgZdqGN4i'
gdown.download(f'https://drive.google.com/uc?export=download&id={tsv_file_id}', 'test.tsv', quiet=False)
mat_file_id = '1Z4Ya9JUCEtYXCiYs8LI9d21C4kSvb0KS'
gdown.download(f'https://drive.google.com/uc?export=download&id={mat_file_id}', 'punctuation_test.mat', quiet=False)

Downloading...
From: https://drive.google.com/uc?export=download&id=1-DfbuRqP5xe9eNMjobEDH_cZR1FlwyEw
To: /kaggle/working/MT_train.csv
100%|██████████| 599k/599k [00:00<00:00, 104MB/s]
Downloading...
From: https://drive.google.com/uc?export=download&id=1pBYaqKbkJt66bJlBJGKvTaWVI6Q0yNCC
To: /kaggle/working/train.tsv
100%|██████████| 1.41M/1.41M [00:00<00:00, 140MB/s]
Downloading...
From (original): https://drive.google.com/uc?export=download&id=1UQTiPe34L4VAOZj3OFOIl2FnHsf7y0nf
From (redirected): https://drive.google.com/uc?export=download&id=1UQTiPe34L4VAOZj3OFOIl2FnHsf7y0nf&confirm=t&uuid=60f6b9a7-d8ff-4624-be0b-ffcce2bed0cb
To: /kaggle/working/punctuation_train.mat
100%|██████████| 173M/173M [00:02<00:00, 61.9MB/s] 
Downloading...
From: https://drive.google.com/uc?export=download&id=1-I-VFAdiAvFF-7x_Y_0_TwDf0v2kl80l
To: /kaggle/working/MT_dev.csv
100%|██████████| 66.2k/66.2k [00:00<00:00, 45.6MB/s]
Downloading...
From: https://drive.google.com/uc?export=download&id=1M6uGyGSOEW7wKirzi

'punctuation_test.mat'

In [25]:
import os
import torch
import pandas as pd
import numpy as np
from scipy.io import loadmat
from transformers import MBartTokenizer

class EnglishTeluguPunctDataset:
    def __init__(self, mat_path, tsv_path, csv_path, tokenizer, max_length=128):
        self.mat_path = mat_path
        self.tsv_path = tsv_path
        self.csv_path = csv_path
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.skipped = 0
        self.skipped_files_names=[]

        # Load .mat file
        mat_data = loadmat(mat_path)
        self.features = mat_data['features'].flatten()
        self.filenames = mat_data['filenames'].flatten()
        
        # Convert filenames to list of strings
        self.filename_to_feat = {}
        for i, fname in enumerate(self.filenames):
            fname_str = str(fname[0]) if isinstance(fname, (np.ndarray, list)) else str(fname)
            self.filename_to_feat[fname_str] = self.features[i]

        # Load MT_data.csv (English + Telugu)
        mt_data = pd.read_csv(csv_path)
        self.id_to_english = mt_data.set_index("id")["english"].to_dict()
        self.id_to_telugu = mt_data.set_index("id")["telugu"].to_dict()

        # # Load TSV file (filename → ID mapping)
        # eng_tsv = pd.read_csv(tsv_path, sep="\t", header=None, dtype=str)
        # eng_tsv = eng_tsv[eng_tsv.columns[:2]]
        # eng_tsv.columns = ["id", "filename"]
        # eng_tsv["id"] = eng_tsv["id"].astype(int)
        # self.filename_to_id = eng_tsv.set_index("filename")["id"].to_dict()

        # Load TSV and filter to only include filenames present in .mat file
        eng_tsv = pd.read_csv(tsv_path, sep="\t", header=None, dtype=str)
        eng_tsv = eng_tsv[eng_tsv.columns[:2]]
        eng_tsv.columns = ["id", "filename"]
        eng_tsv["id"] = eng_tsv["id"].astype(int)
        
        # Only keep rows where filename exists in .mat file
        valid_filenames_set = set(self.filename_to_feat.keys())
        eng_tsv = eng_tsv[eng_tsv["filename"].isin(valid_filenames_set)]
        
        # Build mapping
        self.filename_to_id = eng_tsv.set_index("filename")["id"].to_dict()


        # Build valid entries list
        self.valid_entries = []
        for filename, id_ in self.filename_to_id.items():
            if id_ in self.id_to_english and id_ in self.id_to_telugu and filename in self.filename_to_feat:
                self.valid_entries.append((filename, id_))
            else:
                self.skipped_files_names.append(filename)
                self.skipped += 1

    def __len__(self):
        return len(self.valid_entries)

    def __getitem__(self, idx):
        filename, id_ = self.valid_entries[idx]

        # Load punctuation features
        punct_feats_np = self.filename_to_feat[filename]
        punct_feats = torch.tensor(punct_feats_np, dtype=torch.float32)
        zero_row = torch.zeros((1, punct_feats.size(1)), dtype=torch.float32)
        punct_feats = torch.cat([zero_row, punct_feats, zero_row], dim=0)  

        # Get texts
        src_text = self.id_to_english[id_]
        tgt_text = self.id_to_telugu[id_]

        # Tokenize
        src_enc = self.tokenizer(src_text, return_tensors="pt", padding=False, truncation=True)
        tgt_enc = self.tokenizer(tgt_text, return_tensors="pt", padding=False, truncation=True)

        # Add language code tokens
        src_enc.input_ids[0][0] = self.tokenizer.lang_code_to_id["en_XX"]
        tgt_enc.input_ids[0][0] = self.tokenizer.lang_code_to_id["te_IN"]

        return {
            "input_ids": src_enc.input_ids.squeeze(0),
            "attention_mask": src_enc.attention_mask.squeeze(0),
            "labels": tgt_enc.input_ids.squeeze(0),
            "punct_feats": punct_feats,
            "src": src_text,
            "tgt": tgt_text,
            "filename": filename,
            "id": id_
        }

    def get_skipped_count(self):
        # filename, id_ = self.valid_entries[idx]
        # self.skipped_files_names.append(filename)
        return self.skipped, self.skipped_files_names


# model init

In [4]:
import torch
import torch.nn as nn
from transformers import MBartForConditionalGeneration, MBart50TokenizerFast
from torch.optim import AdamW
from tqdm import tqdm
from transformers.modeling_outputs import BaseModelOutput
from torch.utils.data import DataLoader



model_name = "facebook/mbart-large-50-many-to-many-mmt"
# Load tokenizer and model
# tokenizer = MBart50TokenizerFast.from_pretrained(model_name)
model = MBartForConditionalGeneration.from_pretrained(model_name)

# Initialize the tokenizer
tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")

2025-05-04 18:10:51.387106: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1746382251.561912      31 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1746382251.611514      31 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


config.json:   0%|          | 0.00/1.43k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/2.44G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/261 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/529 [00:00<?, ?B/s]

sentencepiece.bpe.model:   0%|          | 0.00/5.07M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/649 [00:00<?, ?B/s]

In [26]:
# Initialize the dataset
train_dataset = EnglishTeluguPunctDataset(mat_path="punctuation_train.mat",
                                    tsv_path="train.tsv",
                                    csv_path="MT_train.csv",
                                    tokenizer=tokenizer)
# Create DataLoader
train_dataloader = DataLoader(train_dataset, batch_size=1, shuffle=True)

# Initialize the dataset
dev_dataset = EnglishTeluguPunctDataset(mat_path="punctuation_dev.mat",
                                    tsv_path="dev.tsv",
                                    csv_path="MT_dev.csv",
                                    tokenizer=tokenizer)

# Create DataLoader
dev_dataloader = DataLoader(dev_dataset, batch_size=1, shuffle=True)



# Initialize the dataset
test_dataset = EnglishTeluguPunctDataset(mat_path="punctuation_test.mat",
                                    tsv_path="test.tsv",
                                    csv_path="MT_test.csv",
                                    tokenizer=tokenizer)

# Create DataLoader
test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=True)


In [6]:
print(len(train_dataset))
print(len(dev_dataset))
print(len(test_dataset))

699
113
164


In [7]:
SRC_LANG = "en_XX"
TGT_LANG = "te_IN"

tokenizer.src_lang = SRC_LANG
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Set model to training mode
model.train()

optimizer = AdamW(model.parameters(), lr=1e-5)

import torch.nn as nn

from transformers.modeling_outputs import BaseModelOutput
skip=0

import torch.nn as nn

# Assume punct_feats shape: (batch_size, seq_len, feat_dim)
feat_dim = train_dataset[0]["punct_feats"].shape[1]
hidden_dim = model.model.encoder.config.d_model  # mBART hidden size (e.g., 1024)

# Punctuation feature encoder
punct_feat_encoder = nn.Linear(feat_dim, hidden_dim).to(device)

# Add to optimizer
optimizer = AdamW(list(model.parameters()) + list(punct_feat_encoder.parameters()), lr=5e-5)


for epoch in range(3):
    model.train()
    punct_feat_encoder.train()

    for batch in tqdm(train_dataloader, desc=f"Epoch {epoch+1}"):
        optimizer.zero_grad()

        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)
        punct_feats = batch['punct_feats'].to(device)  # (batch, seq_len, feat_dim)

        # Forward encoder
        encoder_outputs = model.model.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True
        )

        # Pass punctuation features through its encoder
        try:
            encoded_punct_feats = punct_feat_encoder(punct_feats)  # (batch, seq_len, hidden_dim)
            combined_hidden_state = encoder_outputs.last_hidden_state + encoded_punct_feats *0.1
        except Exception as e:
            skip += 1
            continue

        encoder_outputs = BaseModelOutput(last_hidden_state=combined_hidden_state)
        decoder_input_ids = model.prepare_decoder_input_ids_from_labels(labels)

        outputs = model(
            input_ids=None,
            attention_mask=None,
            decoder_input_ids=decoder_input_ids,
            encoder_outputs=encoder_outputs,
            labels=labels,
            return_dict=True
        )

        loss = outputs.loss
        if torch.isnan(loss):
            continue
        loss.backward()
        optimizer.step()

    print(f"Epoch {epoch + 1}, Loss: {loss.item():.4f}")


Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.
Epoch 1: 100%|██████████| 699/699 [03:41<00:00,  3.16it/s]


Epoch 1, Loss: 4.8411


Epoch 2: 100%|██████████| 699/699 [03:41<00:00,  3.16it/s]


Epoch 2, Loss: 4.2137


Epoch 3: 100%|██████████| 699/699 [03:41<00:00,  3.16it/s]


Epoch 3, Loss: 0.7976


In [8]:
torch.save(model.state_dict(), "mbart_en_te_speech_punct_with_encoder.pt")

In [9]:
from IPython.display import FileLink

# Replace with your actual .pt file name
FileLink('mbart_en_te_speech_punct_with_encoder.pt')

In [11]:
from transformers import MBartForConditionalGeneration, MBart50TokenizerFast
import torch
import pandas as pd
import csv

# Assuming model and tokenizer are already loaded
# model = MBartForConditionalGeneration.from_pretrained(model_path)
# tokenizer = MBart50TokenizerFast.from_pretrained(model_path)

model.eval()

# Define the source and target languages
SRC_LANG = "en_XX"
TGT_LANG = "te_IN"

# Set the tokenizer source language
tokenizer.src_lang = SRC_LANG

# Placeholder for storing translations
translations = []

# Prepare CSV file (write header once)
output_csv = "dev_translations_speech_encoder.csv"
with open(output_csv, mode='w', newline='', encoding='utf-8') as f:
    writer = csv.DictWriter(f, fieldnames=["filename", "id", "src", "tgt", "translated"])
    writer.writeheader()

    skip2=0
    # Iterate over the dev_loader
    for batch in tqdm(dev_dataloader, desc="Processing test data", unit="batch"):
        # Assuming 'input_ids' and 'attention_mask' are present in the batch
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        punct_feats = batch['punct_feats'].to(device)
        
        # Unwrap values from lists or tensors
        filename = batch['filename'][0] if isinstance(batch['filename'], list) else batch['filename']
        fileid = batch['id'].item() if isinstance(batch['id'], torch.Tensor) else batch['id']
        src_text = batch['src'][0] if isinstance(batch['src'], list) else batch['src']
        tgt_text = batch['tgt'][0] if isinstance(batch['tgt'], list) else batch['tgt']

        
        # Set the language token for the source
        input_ids[0][0] = tokenizer.lang_code_to_id[SRC_LANG]
        
        # Run the encoder manually
        with torch.no_grad():
            encoder_outputs = model.model.encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
                return_dict=True
            )
    
        try:
            # Add punctuation features to encoder hidden states
            combined_hidden_state = encoder_outputs.last_hidden_state
        except:
            skip2+=1
            # print('skipping :(')
            continue
    
        # Wrap into BaseModelOutput
        encoder_outputs = BaseModelOutput(last_hidden_state=combined_hidden_state)
    
        # Generate translation from decoder with forced beginning token for Telugu
        generated_tokens = model.generate(
            input_ids=None,
            encoder_outputs=encoder_outputs,
            attention_mask=None,
            forced_bos_token_id=tokenizer.lang_code_to_id[TGT_LANG],
            max_length=50
        )
        
        # Decode the generated tokens into text
        translated_text = tokenizer.decode(generated_tokens[0], skip_special_tokens=True)
        # print(translated_text)
        translations.append(translated_text)
    
        writer.writerow({
                    "filename": filename,
                    "id": fileid,
                    "src": src_text,
                    "tgt": tgt_text,
                    "translated": translated_text
                })
    
    
    # Print out translations for the entire dev set
    for translated_text in translations:
        print(translated_text)
        break
    
    print(skip2)

Processing test data: 100%|██████████| 113/113 [01:19<00:00,  1.43batch/s]

వేల సంవత్సరాల క్రితం, ఆర్రి స్టార్క్యూస్ అనే మనిషి సోలార్ సిస్టమ్ సూర్యుని చుట్టూ మారినారని చెప్పాడు.
0





In [79]:
import pandas as pd
df= pd.read_csv("dev_translations_speech_encoder.csv")
print(df.head())

                   filename    id  \
0  10749355711908235873.wav  1615   
1  10354232920158081925.wav  1560   
2  16799907084202296752.wav  1658   
3  15553404816205838782.wav  1592   
4  14780052696556980882.wav  1624   

                                                 src  \
0  Thousands of years ago, a man called Aristarch...   
1  As soon as you get out of the current, swimmin...   
2  With only eighteen medals available a day, a n...   
3  When you went abroad at first, people were pro...   
4  Inland waterways can be a good theme to base a...   

                                                 tgt  \
0  వేల సంవత్సరాల క్రితం అరిస్టార్కస్ అనే వ్యక్తి ...   
1  మీరు కరెంట్ నుంచి బయటకు వచ్చిన వెంటనే, తిరిగి ...   
2  రోజుకు కేవలం పద్దెనిమిది మెడల్స్ మాత్రమే అందుబ...   
3  మీరు మొదట విదేశాలకు వెళ్ళినప్పుడు, కొత్త దేశంల...   
4  సెలవుదినం గడపడానికి ఇన్ల్యాండ్ వాటర్​వేస్ మంచి...   

                                          translated  
0  వేల సంవత్సరాల క్రితం, ఆర్రి స్టార్క్యూస్ అనే 

In [78]:
from transformers import MBartForConditionalGeneration, MBart50TokenizerFast
import torch
import pandas as pd
import csv

# Assuming model and tokenizer are already loaded
# model = MBartForConditionalGeneration.from_pretrained(model_path)
# tokenizer = MBart50TokenizerFast.from_pretrained(model_path)

model.eval()

# Define the source and target languages
SRC_LANG = "en_XX"
TGT_LANG = "te_IN"

# Set the tokenizer source language
tokenizer.src_lang = SRC_LANG

# Placeholder for storing translations
translations = []

# Prepare CSV file (write header once)
output_csv = "test_translations_speech_encoder.csv"
with open(output_csv, mode='w', newline='', encoding='utf-8') as f:
    writer = csv.DictWriter(f, fieldnames=["filename", "id", "src", "tgt", "translated"])
    writer.writeheader()

    skip2=0
    # Iterate over the dev_loader
    for batch in tqdm(test_dataloader, desc="Processing test data", unit="batch"):
        # Assuming 'input_ids' and 'attention_mask' are present in the batch
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        punct_feats = batch['punct_feats'].to(device)
        
        # Unwrap values from lists or tensors
        filename = batch['filename'][0] if isinstance(batch['filename'], list) else batch['filename']
        fileid = batch['id'].item() if isinstance(batch['id'], torch.Tensor) else batch['id']
        src_text = batch['src'][0] if isinstance(batch['src'], list) else batch['src']
        tgt_text = batch['tgt'][0] if isinstance(batch['tgt'], list) else batch['tgt']

        
        # Set the language token for the source
        input_ids[0][0] = tokenizer.lang_code_to_id[SRC_LANG]
        
        # Run the encoder manually
        with torch.no_grad():
            encoder_outputs = model.model.encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
                return_dict=True
            )
    
        try:
            # Add punctuation features to encoder hidden states
            combined_hidden_state = encoder_outputs.last_hidden_state
        except:
            skip2+=1
            # print('skipping :(')
            continue
    
        # Wrap into BaseModelOutput
        encoder_outputs = BaseModelOutput(last_hidden_state=combined_hidden_state)
    
        # Generate translation from decoder with forced beginning token for Telugu
        generated_tokens = model.generate(
            input_ids=None,
            encoder_outputs=encoder_outputs,
            attention_mask=None,
            forced_bos_token_id=tokenizer.lang_code_to_id[TGT_LANG],
            max_length=50
        )
        
        # Decode the generated tokens into text
        translated_text = tokenizer.decode(generated_tokens[0], skip_special_tokens=True)
        translations.append(translated_text)
    
        writer.writerow({
                    "filename": filename,
                    "id": fileid,
                    "src": src_text,
                    "tgt": tgt_text,
                    "translated": translated_text
                })
    
    
    # Print out translations for the entire dev set
    for translated_text in translations:
        print(translated_text)
        break
    
    print(skip2)

Processing test data:   5%|▍         | 8/164 [00:05<01:55,  1.35batch/s]


KeyboardInterrupt: 

In [81]:
batch = next(iter(train_dataloader))
print(batch['src'])  # This will give you the 'src' field from the first batch


['The village of Haldarsvík offer views of the nearby island Eysturoy and has an unusual octagonal church.']


In [82]:
csv_file_id = '1FR-MyJRxZH62grW7DihFqbgA8HG0Wnne'
gdown.download(f'https://drive.google.com/uc?export=download&id={csv_file_id}', 'asr_decoded_dev.csv', quiet=False)

Downloading...
From: https://drive.google.com/uc?export=download&id=1FR-MyJRxZH62grW7DihFqbgA8HG0Wnne
To: /kaggle/working/asr_decoded_dev.csv
100%|██████████| 227k/227k [00:00<00:00, 82.8MB/s]


'asr_decoded_dev.csv'

In [83]:
import pandas as pd
import torch

def update_dataset_with_asr(dataset, asr_csv_path, tokenizer, src_lang="en_XX"):
    # Load the ASR CSV
    asr_df = pd.read_csv(asr_csv_path)
    filename_to_asr = dict(zip(asr_df["filename"], asr_df["asr_decoded"]))
    
    # Loop through the dataset and modify the fields
    updated_entries = []
    
    for entry in dataset:
        filename = entry["filename"]  # Get filename from the dataset
        
        if filename in filename_to_asr:
            # Get the corresponding ASR decoded text from CSV
            asr_text = filename_to_asr[filename]
            
            # Tokenize the new ASR text
            enc = tokenizer(asr_text, return_tensors="pt", padding=False, truncation=True)
            enc.input_ids[0][0] = tokenizer.lang_code_to_id[src_lang]  # Add language code to first token
            
            # Update the dataset entry
            entry["input_ids"] = enc.input_ids.squeeze(0)
            entry["attention_mask"] = enc.attention_mask.squeeze(0)
            entry["src"] = asr_text
            
            updated_entries.append(entry)
        else:
            # If no match is found, keep the original entry
            updated_entries.append(entry)
    
    return updated_entries


In [66]:
dev_asr_dataset = update_dataset_with_asr(dev_dataset, "asr_decoded_dev.csv", tokenizer)

In [67]:
print(len(dev_asr_dataset))

113


In [58]:
csv_file_id = '1e0JBRmo98zKK8Lbfof9QFZ7IlQiYtDC2'
gdown.download(f'https://drive.google.com/uc?export=download&id={csv_file_id}', 'asr_decoded_test.csv', quiet=False)

Downloading...
From: https://drive.google.com/uc?export=download&id=1e0JBRmo98zKK8Lbfof9QFZ7IlQiYtDC2
To: /kaggle/working/asr_decoded_test.csv
100%|██████████| 356k/356k [00:00<00:00, 92.9MB/s]


'asr_decoded_test.csv'

In [68]:
test_asr_dataset = update_dataset_with_asr(test_dataset, "asr_decoded_test.csv", tokenizer)

In [69]:
print(len(test_asr_dataset))

164


In [70]:
dev_asr_dataloader = DataLoader(dev_asr_dataset, batch_size=1, shuffle=True)
test_asr_dataloader = DataLoader(test_asr_dataset, batch_size=1, shuffle=True)

In [71]:

model.eval()

# Define the source and target languages
SRC_LANG = "en_XX"
TGT_LANG = "te_IN"

# Set the tokenizer source language
tokenizer.src_lang = SRC_LANG

# Placeholder for storing translations
translations = []

# Prepare CSV file (write header once)
output_csv = "dev_asr_translations_speech_encoder.csv"
with open(output_csv, mode='w', newline='', encoding='utf-8') as f:
    writer = csv.DictWriter(f, fieldnames=["filename", "id", "src", "tgt", "translated"])
    writer.writeheader()

    skip2=0
    # Iterate over the dev_loader
    for batch in tqdm(dev_asr_dataloader, desc="Processing test data", unit="batch"):
        # Assuming 'input_ids' and 'attention_mask' are present in the batch
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        punct_feats = batch['punct_feats'].to(device)
        
        # Unwrap values from lists or tensors
        filename = batch['filename'][0] if isinstance(batch['filename'], list) else batch['filename']
        fileid = batch['id'].item() if isinstance(batch['id'], torch.Tensor) else batch['id']
        src_text = batch['src'][0] if isinstance(batch['src'], list) else batch['src']
        tgt_text = batch['tgt'][0] if isinstance(batch['tgt'], list) else batch['tgt']

        
        # Set the language token for the source
        input_ids[0][0] = tokenizer.lang_code_to_id[SRC_LANG]
        
        # Run the encoder manually
        with torch.no_grad():
            encoder_outputs = model.model.encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
                return_dict=True
            )
    
        try:
            # Add punctuation features to encoder hidden states
            combined_hidden_state = encoder_outputs.last_hidden_state
        except:
            skip2+=1
            # print('skipping :(')
            continue
    
        # Wrap into BaseModelOutput
        encoder_outputs = BaseModelOutput(last_hidden_state=combined_hidden_state)
    
        # Generate translation from decoder with forced beginning token for Telugu
        generated_tokens = model.generate(
            input_ids=None,
            encoder_outputs=encoder_outputs,
            attention_mask=None,
            forced_bos_token_id=tokenizer.lang_code_to_id[TGT_LANG],
            max_length=50
        )
        
        # Decode the generated tokens into text
        translated_text = tokenizer.decode(generated_tokens[0], skip_special_tokens=True)
        # print(translated_text)
        translations.append(translated_text)
    
        writer.writerow({
                    "filename": filename,
                    "id": fileid,
                    "src": src_text,
                    "tgt": tgt_text,
                    "translated": translated_text
                })
    
    
    # Print out translations for the entire dev set
    for translated_text in translations:
        print(translated_text)
        break
    
    print(skip2)

Processing test data: 100%|██████████| 113/113 [01:18<00:00,  1.44batch/s]

ప్రధాన స్థానిక బీర్ ఒకటి. ఇది సంక్లిష్ట బీర్ కాదు కానీ సంతోషంగా మరియు రిఫ్రీట్ చేస్తారు. ఇతర స్థానిక బీర్ మన్టా అని పిలుస్తారు.
0





In [72]:
import pandas as pd
df= pd.read_csv("dev_asr_translations_speech_encoder.csv")
print(df.head())

                   filename    id  \
0  15158676295442294624.wav  1590   
1  12952903060751652532.wav  1590   
2  16131823300806444840.wav  1544   
3  12470893547277455431.wav  1557   
4  16359228487623086121.wav  1557   

                                                 src  \
0  The main local beer is number one. It is not a...   
1  The main local beer is number one. It is not a...   
2  Insects were the first animals to take to the ...   
3  After seeing the horrors and atrocities of war...   
4  After seeing the horrors and atrocities of war...   

                                                 tgt  \
0  "ఇక్కడి ప్రధాన స్థానిక బీర్ 'Number One', ఇది ...   
1  "ఇక్కడి ప్రధాన స్థానిక బీర్ 'Number One', ఇది ...   
2  కీటకాలు గాలిలోకి తీసుకువెళ్ళే మొదటి జంతువులు. ...   
3  రెండవ ప్రపంచ యుద్ధసమయంలో జరిగిన ఘోరాలు, ఘోరాలు...   
4  రెండవ ప్రపంచ యుద్ధసమయంలో జరిగిన ఘోరాలు, ఘోరాలు...   

                                          translated  
0  ప్రధాన స్థానిక బీర్ ఒకటి. ఇది సంక్లిష్ట బీర్ 

In [84]:
model.eval()

# Define the source and target languages
SRC_LANG = "en_XX"
TGT_LANG = "te_IN"

# Set the tokenizer source language
tokenizer.src_lang = SRC_LANG

# Placeholder for storing translations
translations = []

# Prepare CSV file (write header once)
output_csv = "test_asr_translations_speech_encoder.csv"
with open(output_csv, mode='w', newline='', encoding='utf-8') as f:
    writer = csv.DictWriter(f, fieldnames=["filename", "id", "src", "tgt", "translated"])
    writer.writeheader()

    skip2=0
    # Iterate over the dev_loader
    for batch in tqdm(test_asr_dataloader, desc="Processing test data", unit="batch"):
        # Assuming 'input_ids' and 'attention_mask' are present in the batch
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        punct_feats = batch['punct_feats'].to(device)
        
        # Unwrap values from lists or tensors
        filename = batch['filename'][0] if isinstance(batch['filename'], list) else batch['filename']
        fileid = batch['id'].item() if isinstance(batch['id'], torch.Tensor) else batch['id']
        src_text = batch['src'][0] if isinstance(batch['src'], list) else batch['src']
        tgt_text = batch['tgt'][0] if isinstance(batch['tgt'], list) else batch['tgt']

        
        # Set the language token for the source
        input_ids[0][0] = tokenizer.lang_code_to_id[SRC_LANG]
        
        # Run the encoder manually
        with torch.no_grad():
            encoder_outputs = model.model.encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
                return_dict=True
            )
    
        try:
            # Add punctuation features to encoder hidden states
            combined_hidden_state = encoder_outputs.last_hidden_state
        except:
            skip2+=1
            # print('skipping :(')
            continue
    
        # Wrap into BaseModelOutput
        encoder_outputs = BaseModelOutput(last_hidden_state=combined_hidden_state)
    
        # Generate translation from decoder with forced beginning token for Telugu
        generated_tokens = model.generate(
            input_ids=None,
            encoder_outputs=encoder_outputs,
            attention_mask=None,
            forced_bos_token_id=tokenizer.lang_code_to_id[TGT_LANG],
            max_length=50
        )
        
        # Decode the generated tokens into text
        translated_text = tokenizer.decode(generated_tokens[0], skip_special_tokens=True)
        # print(translated_text)
        translations.append(translated_text)
    
        writer.writerow({
                    "filename": filename,
                    "id": fileid,
                    "src": src_text,
                    "tgt": tgt_text,
                    "translated": translated_text
                })
    
    
    # Print out translations for the entire dev set
    for translated_text in translations:
        print(translated_text)
        break
    
    print(skip2)

Processing test data: 100%|██████████| 164/164 [01:56<00:00,  1.41batch/s]

అదృష్టవశాత్తు, డ్రైవర్ ప్రవర్తనను 100 శాతం సురక్షితతో అంచనా వేయడం కష్టం.
0



