# Data-Processing

In [1]:
import pandas as pd
import numpy as np
import os

df = pd.read_csv(r"/kaggle/input/clotho-dataset/clotho_captions_development.csv")
audio_folder = r'/kaggle/input/clotho-dataset/clotho_audio_development'

def find_audio_path(file_name):
    path = os.path.join(audio_folder, file_name)
    if os.path.exists(path):
        return path
    print("not find",path)

df['audio_path'] = df['file_name'].apply(find_audio_path)
df.head()

Unnamed: 0,file_name,caption_1,caption_2,caption_3,caption_4,caption_5,audio_path
0,DistortedAMRadionoise.wav,A muddled noise of broken channel of the TV,A television blares the rhythm of a static TV.,Loud television static dips in and out of focus,The loud buzz of static constantly changes pit...,heavy static and the beginnings of a signal on...,/kaggle/input/clotho-dataset/clotho_audio_deve...
1,PaperParchmentRustling.wav,A person is turning a map over and over.,A person is very carefully rapping a gift for ...,A person is very carefully wrapping a gift for...,"He sighed as he turned the pages of the book, ...","papers are being turned, stopped, then turned ...",/kaggle/input/clotho-dataset/clotho_audio_deve...
2,03WhalesSlowingDown.wav,Several barnyard animals mooing in a barn whil...,"The vocalization of several whales, along with...","Underwater, large numbers of shrimp clicking a...",Whales sing to one another over the flowing wa...,wales sing to one another with water flowing i...,/kaggle/input/clotho-dataset/clotho_audio_deve...
3,Ropetiedtoboatinport.wav,An office chair is squeaking as someone bends ...,Popping and squeaking gradually tapers off to ...,Someone is opening a creaky door slowly while ...,Squeaking and popping followed by gradual popp...,an office chair is squeaking as someone leans ...,/kaggle/input/clotho-dataset/clotho_audio_deve...
4,carpenterbee.wav,A flying bee is buzzing loudly around an objec...,An annoying fly is buzzing loudly and consiste...,An insect buzzing in the foreground as birds c...,"An insect trapped in a spider web struggles, b...","Outdoors, insect trapped in a spider web and t...",/kaggle/input/clotho-dataset/clotho_audio_deve...


# Pre-Processing Data

In [2]:
import torch
from torch.utils.data import Dataset, DataLoader
import torchaudio
import torchaudio.transforms as T
import torchaudio.functional as F
import random

class AudioCaptionDataset(Dataset):
    def __init__(self, audio_paths, captions, tokenizer, sr=32000, n_mels=128, max_len=30, augment = True):
        self.audio_paths = audio_paths
        self.captions = captions
        self.tokenizer = tokenizer
        self.sr = sr #16000 # 42100
        self.n_mels = n_mels
        self.max_len = max_len
        self.augment = augment
        self.mel_transform = T.MelSpectrogram(
            sample_rate=self.sr,
            n_fft=1024, # 4096/ 2048
            win_length=1024, #4096  2048
            hop_length= 320, #512 256
            n_mels=self.n_mels,
            f_min=20,
            f_max=self.sr // 2
        )
        self.spec_augment = torch.nn.Sequential(
            T.FrequencyMasking(freq_mask_param=15),
            T.TimeMasking(time_mask_param=50)
        )
        
    # def add_noise(self, waveform, noise_level=0.005):
    #     noise = torch.randn_like(waveform) * noise_level
    #     return waveform + noise

    # def pitch_shift(self, waveform, n_steps=2):
    #     return F.pitch_shift(waveform, self.sr, n_steps)

    # def time_stretch(self, waveform, rate):
    #     # Chuyển về spectrogram để time-stretch không cần sox
    #     spectrogram = T.Spectrogram()(waveform)
    #     stretcher = T.TimeStretch(hop_length=None, n_freq=spectrogram.shape[1])
    #     stretched_spec = stretcher(spectrogram, rate)
    #     # Convert ngược lại waveform nếu cần
        
    # def waveform_augment(self, waveform):
    #     if random.random() < 0.5:
    #         waveform = self.add_noise(waveform, noise_level=random.uniform(0.003, 0.008))
    #     if random.random() < 0.3:
    #         waveform = self.pitch_shift(waveform, n_steps=random.randint(-1, 1))
    #     if random.random() < 0.3:
    #         waveform = self.time_stretch(waveform, rate=random.uniform(0.95, 1.05))
    #     return waveform

    def __getitem__(self, idx):
        # Load audio
        waveform, orig_sr = torchaudio.load(self.audio_paths[idx])
        if orig_sr != self.sr:
            waveform = T.Resample(orig_sr, self.sr)(waveform)
            
        max_len_audio = self.sr * 25  # 10 giây
        waveform = waveform[:, :max_len_audio]
        if waveform.shape[1] < max_len_audio:
            pad = max_len_audio - waveform.shape[1]
            waveform = torch.nn.functional.pad(waveform, (0, pad))
            
        # if self.augment:
        #     waveform = self.waveform_augment(waveform)
            
        mel_spec = self.mel_transform(waveform)
        log_mel = torch.log1p(mel_spec)
        #log_mel = (log_mel - log_mel.mean()) / log_mel.std()  # Normalize
        log_mel = (log_mel - log_mel.mean(dim=(-2, -1), keepdim=True)) / (log_mel.std(dim=(-2, -1), keepdim=True) + 1e-9)
        log_mel = log_mel.repeat(3, 1, 1) #mel_spec only 1 channel but resnet have 3 channel
        
        if self.augment:
            log_mel = self.spec_augment(log_mel)     
            
        # Tokenize caption
        caption_text = self.captions[idx]
        caption_tokens = self.tokenizer.encode(caption_text, max_len=self.max_len)
        caption_tensor = torch.tensor(caption_tokens, dtype=torch.long)

        return log_mel, caption_tensor  # Now return tensor #log_mel.squeeze(0)

    def __len__(self):
        return len(self.audio_paths)


In [3]:
class SimpleTokenizer:
    def __init__(self, texts, min_freq=1):
        from collections import Counter

        self.special_tokens = ['<pad>', '<bos>', '<eos>', '<unk>']
        counter = Counter()

        for text in texts:
            words = text.lower().strip().split()
            counter.update(words)

        self.word2idx = {token: idx for idx, token in enumerate(self.special_tokens)}
        for word, freq in counter.items():
            if freq >= min_freq and word not in self.word2idx:
                self.word2idx[word] = len(self.word2idx)

        self.idx2word = {idx: word for word, idx in self.word2idx.items()}
        self.vocab_size = len(self.word2idx)
        self.pad_token_id = self.word2idx['<pad>']

    def encode(self, text, max_len=30):
        tokens = [self.word2idx.get(word, self.word2idx['<unk>']) for word in text.lower().strip().split()]
        tokens = [self.word2idx['<bos>']] + tokens + [self.word2idx['<eos>']]
        if len(tokens) < max_len:
            tokens += [self.word2idx['<pad>']] * (max_len - len(tokens))
        else:
            tokens = tokens[:max_len]
        return tokens

    def decode(self, tokens):
        words = [self.idx2word.get(idx, '<unk>') for idx in tokens]
        return ' '.join(words)


## Resize length caption

In [4]:
# # import torch.nn.functional as F

# # def collate_fn(batch):
# #     # batch: list of tuples (mel, caption)
# #     mels, captions = zip(*batch)
    
# #     # Padding audio (log_mel)
# #     max_len = max(mel.shape[-1] for mel in mels)  # lấy max time_steps trong batch

# #     padded_mels = []
# #     for mel in mels:
# #         pad_size = max_len - mel.shape[-1]
# #         padded_mel = F.pad(mel, (0, pad_size))  # Pad cuối chiều time_steps
# #         padded_mels.append(padded_mel)

# #     mels_tensor = torch.stack(padded_mels)
# #     captions_tensor = torch.stack(captions)

# #     return mels_tensor, captions_tensor

# from torch.nn.utils.rnn import pad_sequence

# def collate_fn(batch):
#     audios, captions = zip(*batch)

#     audios = torch.stack(audios)

#     # Pad captions về cùng độ dài
#     captions = [torch.tensor(cap, dtype=torch.long) for cap in captions]
#     captions_padded = pad_sequence(captions, batch_first=True, padding_value=tokenizer.pad_token_id)

#     return audios, captions_padded

In [5]:
from sklearn.model_selection import train_test_split

data = []
for _, row in df.iterrows():
    audio_path = row["audio_path"]
    captions = [row[f"caption_{i}"] for i in range(1, 6)]

    for caption in captions:
        data.append((audio_path, caption))
audio_paths, captions = zip(*data)

train_audio_paths, val_audio_paths, train_captions, val_captions = train_test_split(audio_paths, captions, 
                                                                                    test_size=0.1, random_state=42)

In [6]:
all_captions = list(train_captions) + list(val_captions)
tokenizer = SimpleTokenizer(all_captions)

train_dataset = AudioCaptionDataset(train_audio_paths, train_captions, tokenizer, augment = True)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4) #, collate_fn=collate_fn

val_dataset = AudioCaptionDataset(val_audio_paths, val_captions, tokenizer, augment = False)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=True, num_workers=4) #, collate_fn=collate_fn

# Early Stopping

In [7]:
class EarlyStopping:
    def __init__(self, patience=5, verbose=False, delta=0, path='checkpoint.pt'):
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_loss_min = np.Inf
        self.delta = delta
        self.path = path

    def __call__(self, val_loss, model):

        score = -val_loss

        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model)

        elif score < self.best_score + self.delta:
            self.counter += 1
            if self.verbose:
                print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True

        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0

    def save_checkpoint(self, val_loss, model):
        '''Save model when validation loss decreases.'''
        if self.verbose:
            print(f'Validation loss decreased ({self.val_loss_min:.6f} → {val_loss:.6f}). Saving model ...')
        torch.save(model.state_dict(), self.path)
        self.val_loss_min = val_loss

# Encoder

In [8]:
import torchvision.models as models
import torch.nn as nn

# class AudioEncoder_resnet50(nn.Module):
#     def __init__(self, output_dim=512):  # output_dim bạn vẫn có thể muốn giảm sau ResNet
#         super().__init__()
#         resnet = models.resnet50(pretrained=True)
#         resnet.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)  # chỉnh cho 3-channel log-mel
#         modules = list(resnet.children())[:-1]  # bỏ fc
#         self.resnet = nn.Sequential(*modules)
#         self.fc = nn.Linear(2048, output_dim)  # map từ 2048 → output_dim

#     def forward(self, x):
#         x = self.resnet(x)
#         x = x.view(x.size(0), -1)
#         x = self.fc(x)
#         return x

class AudioEncoder_50(nn.Module):
    def __init__(self, output_dim=2048):
        super().__init__()
        resnet = models.resnet50(pretrained=True)
        modules = list(resnet.children())[:-1]
        self.resnet = nn.Sequential(*modules)
        self.fc = nn.Linear(2048, output_dim)

    def forward(self, x):
        x = self.resnet(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

import torchvision.models as models
from torch import nn

class AudioEncoder_18(nn.Module):
    def __init__(self, output_dim=512):
        super().__init__()
        resnet = models.resnet18(pretrained=True)
        self.features = nn.Sequential(*list(resnet.children())[:-2])  # Bỏ FC
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512, output_dim)
        #self.projection = nn.Linear(512, 256) 

    def forward(self, x):
        # Input x: (batch_size, 1, n_mels, time) → thêm 1 channel
        x = self.features(x)
        x = self.pool(x).squeeze(-1).squeeze(-1)
        x = self.fc(x)
        return x  # (batch_size, output_dim)


# Decoder LSTM

In [9]:
class CaptionDecoder(nn.Module):
    def __init__(self, vocab_size, embed_dim, decoder_dim, hidden_dim, num_layers=1):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_dim)
        self.lstm = nn.LSTM(embed_dim + decoder_dim, hidden_dim, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_dim, vocab_size)

    def forward(self, features, captions):
        # captions: (batch_size, seq_len)
        embeddings = self.embed(captions)
        
        features_expanded = features.unsqueeze(1).repeat(1, embeddings.size(1), 1)
        inputs = torch.cat((features_expanded, embeddings), dim=2)
        #inputs = torch.cat((features.unsqueeze(1), embeddings), dim=1)
        
        hiddens, _ = self.lstm(inputs)
        outputs = self.fc(hiddens)
        return outputs  # (batch_size, seq_len+1, vocab_size)


# Combined Encode + Decoder

In [10]:
class AudioCaptioningModel(nn.Module):
    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, audio_spec, captions):
        features = self.encoder(audio_spec)
        outputs = self.decoder(features, captions)
        return outputs

# Ver2 combine Lstm Attention

In [11]:
# import torch
# import torch.nn as nn
# import torchvision.models as models

# class AudioEncoder_Attention(nn.Module):
#     def __init__(self):
#         super().__init__()
#         resnet = models.resnet50(pretrained=True)
#         modules = list(resnet.children())[:-2]  # Bỏ global avgpool + fc
#         self.resnet = nn.Sequential(*modules)
#         self.adaptive_pool = nn.AdaptiveAvgPool2d((10, 10))  # Có thể thêm pooling nhẹ nếu cần giới hạn time_steps

#     def forward(self, x):
#         x = self.resnet(x)  # (batch_size, 2048, H, W), thường H=W=~7
#         x = self.adaptive_pool(x)  # (batch_size, 2048, 10, 10) để giảm kích thước sequence
#         x = x.permute(0, 2, 3, 1)  # (batch_size, 10, 10, 2048)
#         x = x.view(x.size(0), -1, 2048)  # (batch_size, 100, 2048) --> sequence cho attention
#         return x  # output shape: (batch_size, time_steps, encoder_dim)


In [12]:
# import torch
# import torch.nn as nn
# import torch.nn.functional as F

# class BahdanauAttention(nn.Module):
#     def __init__(self, encoder_dim, decoder_dim, attention_dim):
#         super(BahdanauAttention, self).__init__()
#         self.encoder_att = nn.Linear(encoder_dim, attention_dim)
#         self.decoder_att = nn.Linear(decoder_dim, attention_dim)
#         self.full_att = nn.Linear(attention_dim, 1)

#     def forward(self, encoder_outputs, decoder_hidden):
#         # encoder_outputs: (batch_size, time_steps, encoder_dim)
#         # decoder_hidden: (batch_size, decoder_dim)

#         # Add time dimension to decoder hidden
#         decoder_hidden = decoder_hidden.unsqueeze(1)  # (batch_size, 1, decoder_dim)
        
#         att1 = self.encoder_att(encoder_outputs)  # (batch_size, time_steps, attention_dim)
#         att2 = self.decoder_att(decoder_hidden)#.unsqueeze(1)  # (batch_size, 1, attention_dim)
        
#         att = torch.tanh(att1 + att2)  # broadcast add
#         e = self.full_att(att).squeeze(2)  # (batch_size, time_steps)
        
#         alpha = F.softmax(e, dim=1)  # attention weights (batch_size, time_steps)
        
#         context = (encoder_outputs * alpha.unsqueeze(2)).sum(dim=1)  # (batch_size, encoder_dim)
        
#         return context, alpha


In [13]:
# class CaptionDecoder_Attention(nn.Module):
#     def __init__(self, vocab_size, embed_dim, decoder_dim, hidden_dim, attention_module, num_layers=1):
#         super().__init__()
#         self.embed = nn.Embedding(vocab_size, embed_dim)
#         self.lstm = nn.LSTM(embed_dim + decoder_dim, hidden_dim, num_layers, batch_first=True)
#         self.fc = nn.Linear(hidden_dim, vocab_size)
#         self.attention = attention_module  # <-- attention module đã định nghĩa ở trên
        
#     def forward(self, encoder_outputs, captions):
#         batch_size, seq_len = captions.shape
#         embeddings = self.embed(captions)  # (batch_size, seq_len, embed_dim)
        
#         h = torch.zeros(1, batch_size, self.lstm.hidden_size).to(captions.device)
#         c = torch.zeros(1, batch_size, self.lstm.hidden_size).to(captions.device)

#         outputs = []
#         for t in range(seq_len):
#             embedding_t = embeddings[:, t, :]  # (batch, embed_dim)

#             if t == 0:
#                 # Khởi tạo attention bằng 0
#                 context = torch.zeros(batch_size, encoder_outputs.shape[2]).to(captions.device)
#             else:
#                 # decoder_hidden = h[-1] 
#                 # context, _ = self.attention(encoder_outputs, decoder_hidden)
#                 context, _ = self.attention(encoder_outputs, h.squeeze(0))  # (batch, encoder_dim)

#             lstm_input = torch.cat((embedding_t, context), dim=1).unsqueeze(1)
#             output, (h, c) = self.lstm(lstm_input, (h, c))
#             output_vocab = self.fc(output.squeeze(1))
#             outputs.append(output_vocab)

#         outputs = torch.stack(outputs, dim=1)  # (batch_size, seq_len, vocab_size)
#         return outputs


# Tranning Model

In [14]:
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
from transformers import get_scheduler

def train_one_epoch(model, train_loader, optimizer, criterion, device, lr_scheduler):
    model.train()
    total_loss, total_samples = 0, 0
    loop = tqdm(train_loader, leave=True)
    for mel, captions in loop:
        mel = mel.to(device)              
        captions = captions.to(device)  

        optimizer.zero_grad()
        outputs = model(mel, captions[:, :-1])  
        loss = criterion(outputs.reshape(-1, outputs.size(-1)), captions[:, 1:].reshape(-1))
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        lr_scheduler.step()

        batch_size = mel.size(0)  
        total_loss += loss.item() * batch_size  
        total_samples += batch_size

    avg_loss = total_loss / total_samples
    return avg_loss

def validate(model, val_loader, criterion, device):
    model.eval()
    total_loss, total_samples = 0, 0
    with torch.no_grad():
        loop = tqdm(val_loader, leave=True)
        for mel, captions in loop:
            mel = mel.to(device)              
            captions = captions.to(device)  
        
            outputs = model(mel, captions[:, :-1])   # Teacher Forcing: input là caption trừ token cuối
            loss = criterion(outputs.reshape(-1, outputs.size(-1)), captions[:, 1:].reshape(-1))     
            
            batch_size = mel.size(0)  # số sample trong batch
            total_loss += loss.item() * batch_size  # cộng dồn tổng loss theo sample
            total_samples += batch_size
            
    avg_loss = total_loss / total_samples
    return avg_loss 

In [15]:
vocab_size = tokenizer.vocab_size  
encoder = AudioEncoder_50(output_dim=2048)
#decoder = CaptionDecoder(vocab_size, embed_dim=256, encoder_dim=512, hidden_dim=512)
decoder = CaptionDecoder(vocab_size, embed_dim=512, decoder_dim=2048, hidden_dim=1024)

# encoder = AudioEncoder_Attention()
# attention = BahdanauAttention(encoder_dim=2048, decoder_dim=1024, attention_dim = 512)
# decoder = CaptionDecoder_Attention(vocab_size, embed_dim=512, decoder_dim=2048, hidden_dim=1024, attention_module=attention)


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = AudioCaptioningModel(encoder, decoder)
model = nn.DataParallel(model).to(device)

optimizer = optim.AdamW(model.parameters(), lr=2e-4)
criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)

num_epochs = 30
num_training_steps = num_epochs * len(train_loader)
lr_scheduler = get_scheduler(
    "linear", optimizer=optimizer, num_warmup_steps=0, num_training_steps=num_training_steps)

early_stopper = EarlyStopping(patience=3, verbose=True)


for epoch in range(num_epochs):
    train_loss = train_one_epoch(model, train_loader,  optimizer, criterion, device, lr_scheduler)
    print(f"Epoch [{epoch+1}/{num_epochs}], Train Loss: {train_loss:.4f}")

    val_loss = validate(model, val_loader, criterion, device)
    print(f"Epoch [{epoch+1}/{num_epochs}], Val Loss: {val_loss:.4f}")
    
    early_stopper(val_loss, model)
    if early_stopper.early_stop:
        print("Early stopping")
        break

torch.save(model.state_dict(), "audio_caption_model_for_inference.pth")
print("Model saved for inference!")

Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /root/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth
100%|██████████| 97.8M/97.8M [00:00<00:00, 173MB/s]
100%|██████████| 540/540 [12:14<00:00,  1.36s/it]


Epoch [1/30], Train Loss: 5.2195


100%|██████████| 60/60 [01:05<00:00,  1.09s/it]


Epoch [1/30], Val Loss: 5.0863
Validation loss decreased (inf → 5.086333). Saving model ...


100%|██████████| 540/540 [12:43<00:00,  1.41s/it]


Epoch [2/30], Train Loss: 4.2925


100%|██████████| 60/60 [01:13<00:00,  1.22s/it]


Epoch [2/30], Val Loss: 4.1916
Validation loss decreased (5.086333 → 4.191587). Saving model ...


100%|██████████| 540/540 [12:27<00:00,  1.38s/it]


Epoch [3/30], Train Loss: 3.8754


100%|██████████| 60/60 [01:04<00:00,  1.08s/it]


Epoch [3/30], Val Loss: 3.9509
Validation loss decreased (4.191587 → 3.950950). Saving model ...


100%|██████████| 540/540 [12:19<00:00,  1.37s/it]


Epoch [4/30], Train Loss: 3.5677


100%|██████████| 60/60 [01:05<00:00,  1.10s/it]


Epoch [4/30], Val Loss: 3.7890
Validation loss decreased (3.950950 → 3.789046). Saving model ...


100%|██████████| 540/540 [12:19<00:00,  1.37s/it]


Epoch [5/30], Train Loss: 3.3160


100%|██████████| 60/60 [01:05<00:00,  1.09s/it]


Epoch [5/30], Val Loss: 3.7219
Validation loss decreased (3.789046 → 3.721883). Saving model ...


100%|██████████| 540/540 [12:19<00:00,  1.37s/it]


Epoch [6/30], Train Loss: 3.0913


100%|██████████| 60/60 [01:05<00:00,  1.08s/it]


Epoch [6/30], Val Loss: 3.6243
Validation loss decreased (3.721883 → 3.624317). Saving model ...


100%|██████████| 540/540 [12:19<00:00,  1.37s/it]


Epoch [7/30], Train Loss: 2.8848


100%|██████████| 60/60 [01:05<00:00,  1.08s/it]


Epoch [7/30], Val Loss: 3.5715
Validation loss decreased (3.624317 → 3.571480). Saving model ...


100%|██████████| 540/540 [12:19<00:00,  1.37s/it]


Epoch [8/30], Train Loss: 2.6953


100%|██████████| 60/60 [01:04<00:00,  1.08s/it]


Epoch [8/30], Val Loss: 3.6524
EarlyStopping counter: 1 out of 3


100%|██████████| 540/540 [12:19<00:00,  1.37s/it]


Epoch [9/30], Train Loss: 2.5191


100%|██████████| 60/60 [01:06<00:00,  1.10s/it]


Epoch [9/30], Val Loss: 3.4974
Validation loss decreased (3.571480 → 3.497426). Saving model ...


100%|██████████| 540/540 [12:18<00:00,  1.37s/it]


Epoch [10/30], Train Loss: 2.3572


100%|██████████| 60/60 [01:04<00:00,  1.08s/it]


Epoch [10/30], Val Loss: 3.4769
Validation loss decreased (3.497426 → 3.476945). Saving model ...


100%|██████████| 540/540 [12:19<00:00,  1.37s/it]


Epoch [11/30], Train Loss: 2.2083


100%|██████████| 60/60 [01:05<00:00,  1.08s/it]


Epoch [11/30], Val Loss: 3.4776
EarlyStopping counter: 1 out of 3


100%|██████████| 540/540 [12:19<00:00,  1.37s/it]


Epoch [12/30], Train Loss: 2.0695


100%|██████████| 60/60 [01:05<00:00,  1.09s/it]


Epoch [12/30], Val Loss: 3.4667
Validation loss decreased (3.476945 → 3.466704). Saving model ...


100%|██████████| 540/540 [12:19<00:00,  1.37s/it]


Epoch [13/30], Train Loss: 1.9446


100%|██████████| 60/60 [01:05<00:00,  1.09s/it]


Epoch [13/30], Val Loss: 3.4638
Validation loss decreased (3.466704 → 3.463775). Saving model ...


100%|██████████| 540/540 [12:18<00:00,  1.37s/it]


Epoch [14/30], Train Loss: 1.8284


100%|██████████| 60/60 [01:05<00:00,  1.09s/it]


Epoch [14/30], Val Loss: 3.4567
Validation loss decreased (3.463775 → 3.456737). Saving model ...


100%|██████████| 540/540 [12:19<00:00,  1.37s/it]


Epoch [15/30], Train Loss: 1.7236


100%|██████████| 60/60 [01:05<00:00,  1.09s/it]


Epoch [15/30], Val Loss: 3.4544
Validation loss decreased (3.456737 → 3.454409). Saving model ...


100%|██████████| 540/540 [12:19<00:00,  1.37s/it]


Epoch [16/30], Train Loss: 1.6265


100%|██████████| 60/60 [01:05<00:00,  1.09s/it]


Epoch [16/30], Val Loss: 3.5048
EarlyStopping counter: 1 out of 3


100%|██████████| 540/540 [12:20<00:00,  1.37s/it]


Epoch [17/30], Train Loss: 1.5413


100%|██████████| 60/60 [01:05<00:00,  1.09s/it]


Epoch [17/30], Val Loss: 3.4644
EarlyStopping counter: 2 out of 3


100%|██████████| 540/540 [12:20<00:00,  1.37s/it]


Epoch [18/30], Train Loss: 1.4574


100%|██████████| 60/60 [01:05<00:00,  1.09s/it]


Epoch [18/30], Val Loss: 3.4669
EarlyStopping counter: 3 out of 3
Early stopping
Model saved for inference!


In [16]:
torch.save(model.state_dict(), "audio_caption_model_for_inference.pth")
print("Model saved for inference!")
# Lưu token như này để test model local
torch.save(tokenizer, "tokenizer.pt")
print("Tokenizer saved for inference!")

Model saved for inference!
Tokenizer saved for inference!


In [17]:
# Lưu token này để deploy tiện hơn
import json

with open("word2idx.json", "w") as f:
    json.dump(tokenizer.word2idx, f)

In [None]:
# import pickle

# with open('tokenizer.pkl', 'wb') as f:
#     pickle.dump(tokenizer, f)

# with open('tokenizer.pkl', 'rb') as f:
#     tokenizer = pickle.load(f)

# # tokenizer là instance của một class kế thừa nn.Module hoặc có hỗ trợ state_dict
# torch.save(tokenizer.state_dict(), "tokenizer.pt")

# from model import SimpleTokenizer  # phải có class này định nghĩa đúng

# # 1. Tạo lại tokenizer đúng cấu trúc ban đầu (nhưng chưa có weights)
# tokenizer = SimpleTokenizer(...)  # truyền các config nếu cần

# # 2. Load weights
# tokenizer.load_state_dict(torch.load("tokenizer.pt"))


In [None]:
# Save word2idx khi muốn mang đi deploy ở chõ khác tải cái token này
# with open('word2idx.pkl', 'wb') as f:
#     pickle.dump(tokenizer.word2idx, f)

# Test Model

In [None]:
import torchaudio
import torchaudio.transforms as T
import torch.nn.functional as F

def preprocess_audio(audio_path, sr=16000, n_mels=64, max_audio_len=25):
    waveform, orig_sr = torchaudio.load(audio_path)
    if orig_sr != sr:
        waveform = T.Resample(orig_sr, sr)(waveform)
    
    max_len_audio = sr * max_audio_len
    waveform = waveform[:, :max_len_audio]
    if waveform.shape[1] < max_len_audio:
        pad = max_len_audio - waveform.shape[1]
        waveform = F.pad(waveform, (0, pad))

    mel_transform = T.MelSpectrogram(sample_rate=sr, n_mels=n_mels)
    mel_spec = mel_transform(waveform)
    log_mel = torch.log1p(mel_spec)
    log_mel = (log_mel - log_mel.mean()) / log_mel.std()
    log_mel = log_mel.repeat(3, 1, 1)  # (3, n_mels, time)
    
    return log_mel


In [None]:
def generate_caption(model, mel, tokenizer, max_len=20):
    model.eval()
    device = next(model.parameters()).device
    
    with torch.no_grad():
        features = model.encoder(mel.to(device))
        
        # Bắt đầu với token <SOS>
        input_word = torch.tensor([[tokenizer.word2idx['<bos>']]]).to(device)
        captions = []
        hidden = None

        for _ in range(max_len):
            embedding = model.decoder.embed(input_word)
            lstm_input = torch.cat((features.unsqueeze(1), embedding), dim=2)
            output, hidden = model.decoder.lstm(lstm_input, hidden)
            output_vocab = model.decoder.fc(output.squeeze(1))
            predicted = output_vocab.argmax(1)
            
            word = tokenizer.idx2word[predicted.item()]
            if word == '<eos>':
                break
            captions.append(word)
            input_word = predicted.unsqueeze(0)
    
    return ' '.join(captions)


In [None]:
# Load raw state_dict
state_dict = torch.load("/kaggle/working/audio_caption_model_for_inference.pth", map_location=device)

# Hàm để xóa 'module.' nếu có
def remove_module_prefix(state_dict):
    new_state_dict = {}
    for k, v in state_dict.items():
        if k.startswith("module."):
            new_key = k[7:]  # bỏ 'module.' đi
        else:
            new_key = k
        new_state_dict[new_key] = v
    return new_state_dict

new_state_dict = remove_module_prefix(state_dict)

model.load_state_dict(new_state_dict)
model.eval()


In [None]:
# mel = preprocess_audio("/kaggle/input/clotho-dataset/clotho_audio_development/0111Ambulance.wav") 
# mel = mel.unsqueeze(0).to(device)  # thêm batch dimension

# with torch.no_grad():
#     features = model.encoder(mel)
#     caption = generate_caption(model.decoder, features, tokenizer)
#     print(caption)

In [None]:
mel = preprocess_audio("/kaggle/input/clotho-dataset/clotho_audio_development/00264hillcreek1.wav")
mel = mel.unsqueeze(0).to(device)  # thêm batch dimension
caption = generate_caption(model, mel, tokenizer)
print("Generated Caption:", caption)