In [1]:
# Install required packages including torchcodec
!pip install -q datasets==3.6.0 \
sentence-transformers==4.1.0 \
soundfile==0.13.1 \
speechbrain==1.0.3 \
torchaudio==2.6.0 \
transformers==4.52.4 \
torchcodec


In [None]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchaudio
from datasets import load_dataset, Dataset as HFDataset  # Fixed import
from huggingface_hub import login
from speechbrain.inference import EncoderClassifier
from google.colab import userdata
import numpy as np
from collections import defaultdict

# ===========================
# Setup device
# ===========================
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# ===========================
# Hugging Face Authentication
# ===========================
from huggingface_hub import login
login(token = "put your hugging face token here")

Using device: cuda


In [3]:
print("Loading pretrained encoders...")
language_id = EncoderClassifier.from_hparams(
    source="TalTechNLP/voxlingua107-epaca-tdnn",
    savedir="tmp_lang",
    run_opts={"device": device}
)

speaker_id = EncoderClassifier.from_hparams(
    source="speechbrain/spkrec-ecapa-voxceleb",
    savedir="tmp_spk",
    run_opts={"device": device}
)

Loading pretrained encoders...


  wrapped_fwd = torch.cuda.amp.custom_fwd(fwd, cast_inputs=cast_inputs)


In [4]:
# Unfreeze encoder parameters for end-to-end training
for param in language_id.parameters():
    param.requires_grad = True

for param in speaker_id.parameters():
    param.requires_grad = True

In [5]:
@torch.no_grad()
def get_lang_embedding(waveforms: torch.Tensor) -> torch.Tensor:
    """Extract language embeddings from waveforms"""
    emb = language_id.encode_batch(waveforms.to(device))
    return emb.squeeze(1).to(device)

@torch.no_grad()
def get_spk_embeddings(waveforms: torch.Tensor, sr: int = 16000) -> torch.Tensor:
    """Extract speaker embeddings from waveforms"""
    emb = speaker_id.encode_batch(waveforms.to(device))
    return emb.squeeze(1).to(device)

In [6]:
# ===========================
# CrossFeaturePrefixTuner
# ===========================
class CrossFeaturePrefixTuner(nn.Module):
    def __init__(self, dim: int, num_heads: int = 4, prefix_len: int = 5):
        super().__init__()
        self.prefix_k = nn.Parameter(torch.randn(prefix_len, dim) * 0.02)
        self.prefix_v = nn.Parameter(torch.randn(prefix_len, dim) * 0.02)
        self.attn = nn.MultiheadAttention(embed_dim=dim, num_heads=num_heads, batch_first=True)

    def forward(self, query: torch.Tensor, kv: torch.Tensor) -> torch.Tensor:
        B = kv.size(0)
        prefix_k = self.prefix_k.unsqueeze(0).expand(B, -1, -1).to(kv.device)
        prefix_v = self.prefix_v.unsqueeze(0).expand(B, -1, -1).to(kv.device)

        k = torch.cat([prefix_k, kv], dim=1)
        v = torch.cat([prefix_v, kv], dim=1)

        out, _ = self.attn(query, k, v)
        return out

# ===========================
# Sinusoidal Positional Encoding
# ===========================
def sinusoidal_pos_enc(seq_len: int, dim: int, device) -> torch.Tensor:
    pe = torch.zeros(seq_len, dim, device=device)
    position = torch.arange(0, seq_len, device=device).unsqueeze(1)
    div_term = torch.exp(torch.arange(0, dim, 2, device=device) * (-math.log(10000.0) / dim))
    pe[:, 0::2] = torch.sin(position * div_term)
    pe[:, 1::2] = torch.cos(position * div_term)
    return pe.unsqueeze(0)

In [7]:
# ===========================
# DecoderLSTMFCUp
# ===========================
class DecoderLSTMFCUp(nn.Module):
    def __init__(self, embed_dim: int, n_mels: int = 80, hidden_dim: int = 512,
                 lstm_layers: int = 2, dropout: float = 0.1):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.n_mels = n_mels

        self.pre = nn.Sequential(
            nn.LayerNorm(embed_dim),
            nn.Linear(embed_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
        )

        self.lstm = nn.LSTM(
            input_size=hidden_dim,
            hidden_size=hidden_dim,
            num_layers=lstm_layers,
            batch_first=True,
            dropout=dropout if lstm_layers > 1 else 0.0,
            bidirectional=False,
        )

        self.head = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, n_mels),
        )

        self.out_ln = nn.LayerNorm(n_mels)

    def forward(self, query: torch.Tensor, out_len: int):
        B, seq_len, _ = query.size()

        # Adjust query length to match output length
        if seq_len < out_len:
            repeat_factor = (out_len + seq_len - 1) // seq_len
            query = query.repeat(1, repeat_factor, 1)
            query = query[:, :out_len, :]
        elif seq_len > out_len:
            query = query[:, :out_len, :]

        x = self.pre(query)
        lstm_out, _ = self.lstm(x)
        out = self.head(lstm_out)
        out = self.out_ln(out)
        return out

# ===========================
# AAM Softmax Loss (Additive Angular Margin)
# ===========================
class AAMSoftmax(nn.Module):
    def __init__(self, n_classes, feat_dim, s=30.0, m=0.2):
        super(AAMSoftmax, self).__init__()
        self.n_classes = n_classes
        self.feat_dim = feat_dim
        self.s = s  # Scale factor
        self.m = m  # Margin

        # Weight normalization
        self.weight = nn.Parameter(torch.randn(n_classes, feat_dim))
        nn.init.xavier_uniform_(self.weight)

    def forward(self, x, labels):
        # Normalize features and weights
        x_norm = F.normalize(x, p=2, dim=1)
        w_norm = F.normalize(self.weight, p=2, dim=1)

        # Compute cosine similarity
        cosine = F.linear(x_norm, w_norm)

        # Add margin to the target angle
        phi = cosine - self.m

        # One-hot encoding
        one_hot = torch.zeros_like(cosine)
        one_hot.scatter_(1, labels.view(-1, 1).long(), 1)

        # Output with margin
        output = (one_hot * phi) + ((1.0 - one_hot) * cosine)
        output *= self.s

        return F.cross_entropy(output, labels)


In [8]:
# ===========================
# LASPA Model
# ===========================
class LASPA(nn.Module):
    def __init__(self, proj_dim: int = 192, num_heads: int = 4, prefix_len: int = 5,
                 n_mels: int = 80, hidden_dim: int = 512, num_speakers: int = 1000,
                 num_languages: int = 10):
        super().__init__()

        # Projection layers
        self.lang_proj_layer = nn.Linear(256, proj_dim)
        self.spk_proj_layer = nn.Linear(192, proj_dim)  # Speaker embedding is 192-dim

        # Cross-feature prefix tuners
        self.spk2lang = CrossFeaturePrefixTuner(proj_dim, num_heads=num_heads, prefix_len=prefix_len)
        self.lang2spk = CrossFeaturePrefixTuner(proj_dim, num_heads=num_heads, prefix_len=prefix_len)

        # Decoder
        self.decoder = DecoderLSTMFCUp(embed_dim=2 * proj_dim, n_mels=n_mels, hidden_dim=hidden_dim)

        # Mel spectrogram transform
        self.mel_tf = torchaudio.transforms.MelSpectrogram(
            sample_rate=16000,
            n_fft=400,
            win_length=400,
            hop_length=160,
            f_min=0.0,
            f_max=8000.0,
            n_mels=n_mels,
            center=True,
            pad_mode="reflect",
            power=2.0,
            norm=None,
            mel_scale="htk"
        ).to(device)  # Move to device

        # Classification heads for AAM Softmax and NLL losses
        self.speaker_classifier = AAMSoftmax(num_speakers, proj_dim, s=30.0, m=0.2)
        self.language_classifier = nn.Linear(proj_dim, num_languages)

    @staticmethod
    def _log_mel(mel: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
        return torch.log(mel + eps)

    def forward(self, waveforms: torch.Tensor, speaker_ids=None, language_ids=None):
        B = waveforms.size(0)

        # Get embeddings
        with torch.no_grad():
            spk_raw = get_spk_embeddings(waveforms)
            lang_raw = get_lang_embedding(waveforms)

        # Project embeddings
        spk_proj = self.spk_proj_layer(spk_raw)
        lang_proj = self.lang_proj_layer(lang_raw)

        # Prepare for cross-attention
        spk_q = spk_proj.unsqueeze(1)
        lang_q = lang_proj.unsqueeze(1)

        # Cross-feature fusion
        spk2lang_out = self.spk2lang(spk_q, lang_q)
        lang2spk_out = self.lang2spk(lang_q, spk_q)

        # Concatenate fused features
        fused = torch.cat([spk2lang_out, lang2spk_out], dim=-1)

        # Generate mel spectrogram target
        with torch.no_grad():
            mel = self.mel_tf(waveforms.to(device))
            mel = self._log_mel(mel)
            mel_target = mel.transpose(1, 2)

        out_len = mel_target.size(1)

        # Decode to mel spectrogram
        mel_hat = self.decoder(fused, out_len=out_len)

        # Compute classification logits for losses
        speaker_loss = None
        language_logits = None

        if speaker_ids is not None:
            speaker_loss = self.speaker_classifier(spk_proj, speaker_ids)

        if language_ids is not None:
            language_logits = self.language_classifier(lang_proj)

        return mel_hat, mel_target, spk_proj, lang_proj, speaker_loss, language_logits


In [9]:
# ===========================
# Loss Functions
# ===========================
def compute_lmse(recon: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Mean Squared Error Loss for mel spectrogram reconstruction"""
    return F.mse_loss(recon, target)

def compute_lmapc(spk_emb: torch.Tensor, lang_emb: torch.Tensor) -> torch.Tensor:
    """Mean Absolute Pearson Correlation loss for disentanglement"""
    spk = spk_emb - spk_emb.mean(dim=-1, keepdim=True)
    lang = lang_emb - lang_emb.mean(dim=-1, keepdim=True)

    num = (spk * lang).sum(dim=-1)
    den = torch.norm(spk, dim=-1) * torch.norm(lang, dim=-1) + 1e-8
    corr = num / den

    return corr.abs().mean()  # Return absolute correlation (want to minimize)

class LASPALoss(nn.Module):
    def __init__(self, alpha: float = 1.0, beta: float = 1.0, gamma: float = 1.0, delta: float = 1.0):
        super().__init__()
        self.alpha = nn.Parameter(torch.tensor(alpha))  # MSE weight
        self.beta = nn.Parameter(torch.tensor(beta))    # MAPC weight
        self.gamma = nn.Parameter(torch.tensor(gamma))  # AAM Softmax weight
        self.delta = nn.Parameter(torch.tensor(delta))  # NLL weight

    def forward(self, mel_hat, mel_target, spk_emb, lang_emb, speaker_loss, language_logits, language_ids):
        # MSE Loss for reconstruction
        lmse = compute_lmse(mel_hat, mel_target)

        # MAPC Loss for disentanglement
        lmapc = compute_lmapc(spk_emb, lang_emb)

        # Initialize classification losses
        aam_loss = torch.tensor(0.0, device=mel_hat.device)
        nll_loss = torch.tensor(0.0, device=mel_hat.device)

        # AAM Softmax Loss (already computed in model)
        if speaker_loss is not None:
            aam_loss = speaker_loss

        # NLL Loss for language classification
        if language_logits is not None and language_ids is not None:
            nll_loss = F.cross_entropy(language_logits, language_ids)

        # Total weighted loss
        total = self.alpha * lmse + self.beta * lmapc + self.gamma * aam_loss + self.delta * nll_loss

        metrics = {
            "LMSE": float(lmse.detach().cpu()),
            "LMAPC": float(lmapc.detach().cpu()),
            "AAM": float(aam_loss.detach().cpu()),
            "NLL": float(nll_loss.detach().cpu())
        }

        return total, metrics


In [10]:
from datasets import load_dataset
import torchaudio
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
from tqdm import tqdm

# -------------------------------
# Kathbath Dataset (non-streaming)
# -------------------------------
class KathbathDataset(Dataset):
    def __init__(self, languages=["sanskrit"], split="train", max_samples_per_lang=None, target_sr=16000):
        super().__init__()
        self.languages = languages
        self.split = split
        self.target_sr = target_sr
        self.data = []
        
        # Global mappings
        self.speaker_to_id = {}
        self.language_to_id = {lang: i for i, lang in enumerate(languages)}
        self.lang_speakers = {lang: set() for lang in languages}

        # Load datasets fully
        for lang in languages:
            try:
                ds = load_dataset("ai4bharat/Kathbath", lang, split=split)
            except Exception as e:
                print(f"Warning: Could not load {lang} dataset: {e}")
                continue

            print(f"Loading {lang} samples...")
            for i, item in enumerate(tqdm(ds, ncols=100)):
                if max_samples_per_lang and i >= max_samples_per_lang:
                    break

                audio = item.get("audio_filepath", None)
                if audio is None or "array" not in audio:
                    continue

                speaker = item.get("speaker_id", f"{lang}_unk")
                if speaker not in self.speaker_to_id:
                    self.speaker_to_id[speaker] = len(self.speaker_to_id)
                self.lang_speakers[lang].add(speaker)

                waveform = torch.tensor(audio["array"]).float()
                sr = audio["sampling_rate"]
                if sr != self.target_sr:
                    waveform = torchaudio.transforms.Resample(sr, self.target_sr)(waveform)

                self.data.append({
                    "waveform": waveform,
                    "speaker_id": self.speaker_to_id[speaker],
                    "language_id": self.language_to_id[lang]
                })

        if len(self.data) == 0:
            print("Warning: No valid samples found!")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

# -------------------------------
# Collate function
# -------------------------------
def collate_fn(batch):
    max_len = max(x["waveform"].shape[0] for x in batch)
    waveforms, speakers, langs = [], [], []

    for x in batch:
        wf = x["waveform"]
        if wf.shape[0] < max_len:
            wf = F.pad(wf, (0, max_len - wf.shape[0]))
        else:
            wf = wf[:max_len]
        waveforms.append(wf)
        speakers.append(x["speaker_id"])
        langs.append(x["language_id"])

    return {
        "waveform": torch.stack(waveforms),
        "speaker_id": torch.tensor(speakers, dtype=torch.long),
        "language_id": torch.tensor(langs, dtype=torch.long),
    }

# -------------------------------
# Setup DataLoader
# -------------------------------
def setup_dataloader(langs=["sanskrit"], split="train", batch_size=4, max_samples_per_lang=None):
    dataset = KathbathDataset(languages=langs, split=split, max_samples_per_lang=max_samples_per_lang)
    dataloader = DataLoader(dataset, batch_size=batch_size, collate_fn=collate_fn, shuffle=True, num_workers=0)

    print(f"\nTotal samples: {len(dataset)}")
    print(f"Total Speakers: {len(dataset.speaker_to_id)}, Total Languages: {len(dataset.language_to_id)}")
    for lang in langs:
        print(f"  {lang}: {len(dataset.lang_speakers[lang])} speakers")

    return dataloader, len(dataset.speaker_to_id), len(dataset.language_to_id)



In [11]:
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE

@torch.no_grad()
def visualize_embeddings(model, dataloader, max_batches=5):
    all_spk, all_lang, all_spk_id, all_lang_id = [], [], [], []
    for i, batch in enumerate(dataloader):
        if i >= max_batches: break
        wave = batch["waveform"].to(device)
        spk_id = batch["speaker_id"].cpu().numpy()
        lang_id = batch["language_id"].cpu().numpy()

        _, _, spk_emb, lang_emb, _, _ = model(wave)

        all_spk.append(spk_emb.cpu())
        all_lang.append(lang_emb.cpu())
        all_spk_id.extend(spk_id)
        all_lang_id.extend(lang_id)

    spk_mat = torch.cat(all_spk).numpy()
    lang_mat = torch.cat(all_lang).numpy()

    tsne = TSNE(n_components=2, random_state=42)
    spk_2d = tsne.fit_transform(spk_mat)
    lang_2d = tsne.fit_transform(lang_mat)

    plt.figure(figsize=(12,5))
    plt.subplot(1,2,1)
    plt.scatter(spk_2d[:,0], spk_2d[:,1], c=all_spk_id, cmap="tab20", s=10)
    plt.title("Speaker Embeddings")
    plt.subplot(1,2,2)
    plt.scatter(lang_2d[:,0], lang_2d[:,1], c=all_lang_id, cmap="tab10", s=10)
    plt.title("Language Embeddings")
    plt.show()


In [12]:
def plot_mel_comparison(mel_true, mel_recon, idx=0):
    fig, axes = plt.subplots(1,2, figsize=(10,4))
    axes[0].imshow(mel_true[idx].T.cpu().numpy(), aspect="auto", origin="lower")
    axes[0].set_title("Ground Truth Mel")
    axes[1].imshow(mel_recon[idx].T.detach().cpu().numpy(), aspect="auto", origin="lower")
    axes[1].set_title("Reconstructed Mel")
    plt.show()

In [13]:
# ===========================
# Training Loop
# ===========================
def training_loop(model, loss_fn, optimizer, scheduler, train_loader, num_epochs=5):
    print("Starting training...")

    for epoch in range(1, num_epochs + 1):
        model.train()

        total_loss = 0.0
        total_lmse = 0.0
        total_lmapc = 0.0
        total_aam = 0.0
        total_nll = 0.0
        num_batches = 0

        for batch_idx, batch in enumerate(train_loader):
            try:
                # Move data to device
                waveforms = batch['waveform'].to(device)
                speaker_ids = batch['speaker_id'].to(device)
                language_ids = batch['language_id'].to(device)

                # Zero gradients
                optimizer.zero_grad()

                # Forward pass
                mel_hat, mel_target, spk_emb, lang_emb, speaker_loss, language_logits = model(
                    waveforms, speaker_ids, language_ids
                )

                # Compute loss
                loss, metrics = loss_fn(
                    mel_hat, mel_target, spk_emb, lang_emb,
                    speaker_loss, language_logits, language_ids
                )

                # Backward pass
                loss.backward()

                # Gradient clipping
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

                # Update weights
                optimizer.step()

                # Accumulate metrics
                total_loss += loss.item()
                total_lmse += metrics["LMSE"]
                total_lmapc += metrics["LMAPC"]
                total_aam += metrics["AAM"]
                total_nll += metrics["NLL"]
                num_batches += 1

                # Print progress
                if batch_idx % 5 == 0:
                        print(f"Epoch {epoch}, Batch {batch_idx+1}")  # <-- remove len(train_loader)
                        print(f"  Loss: {loss.item():.4f}")
                        print(f"  LMSE: {metrics['LMSE']:.4f}, LMAPC: {metrics['LMAPC']:.4f}")
                        print(f"  AAM: {metrics['AAM']:.4f}, NLL: {metrics['NLL']:.4f}")

                # Early stopping for quick testing (remove this for full training)
                if batch_idx >= 10:  # Just 10 batches per epoch for initial testing
                    break

            except Exception as e:
                print(f"Error in batch {batch_idx}: {e}")
                import traceback
                traceback.print_exc()
                continue

        if num_batches > 0:
            avg_loss = total_loss / num_batches
            avg_lmse = total_lmse / num_batches
            avg_lmapc = total_lmapc / num_batches
            avg_aam = total_aam / num_batches
            avg_nll = total_nll / num_batches

            # Update learning rate
            scheduler.step(avg_loss)

            print(f"\n{'='*50}")
            print(f"Epoch {epoch} Summary:")
            print(f"  Avg Loss: {avg_loss:.4f}")
            print(f"  Avg LMSE: {avg_lmse:.4f}, Avg LMAPC: {avg_lmapc:.4f}")
            print(f"  Avg AAM: {avg_aam:.4f}, Avg NLL: {avg_nll:.4f}")
            print(f"  Loss weights - α: {loss_fn.alpha.item():.4f}, β: {loss_fn.beta.item():.4f}")
            print(f"                γ: {loss_fn.gamma.item():.4f}, δ: {loss_fn.delta.item():.4f}")
            print(f"  Learning rate: {optimizer.param_groups[0]['lr']:.6f}")
            print(f"{'='*50}\n")

            # Save checkpoint every epoch
            save_checkpoint(epoch, model, optimizer, loss_fn, avg_loss)

    return model


In [14]:
# ===========================
# Utility Functions
# ===========================
def save_checkpoint(epoch, model, optimizer, loss_fn, avg_loss):
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss_fn_state_dict': loss_fn.state_dict(),
        'avg_loss': avg_loss,
    }
    torch.save(checkpoint, f'laspa_checkpoint_epoch_{epoch}.pth')
    print(f"Saved checkpoint at epoch {epoch}")

def save_final_model(model):
    torch.save(model.state_dict(), 'laspa_final_model.pth')
    print("Final model saved as 'laspa_final_model.pth'")

In [15]:
def setup_model_and_optimizer(num_speakers, num_languages, device=device, lr=1e-3):
    # Initialize model
    model = LASPA(num_speakers=num_speakers, num_languages=num_languages).to(device)

    # Loss function
    loss_fn = LASPALoss(alpha=1.0, beta=1.0, gamma=1.0, delta=1.0).to(device)

    # Optimizer
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)

    # Scheduler
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=1)

    return model, loss_fn, optimizer, scheduler


In [16]:
# ===========================
# Main Training Function
# ===========================
def train_laspa():
    """Main function to orchestrate the entire training process"""
    try:
        # Setup data
        train_loader, num_speakers, num_languages = setup_dataloader()

        # Setup model and training components
        model, loss_fn, optimizer, scheduler = setup_model_and_optimizer(
            num_speakers, num_languages
        )

        # Print model info
        total_params = sum(p.numel() for p in model.parameters())
        trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        print(f"\nModel Information:")
        print(f"Total parameters: {total_params:,}")
        print(f"Trainable parameters: {trainable_params:,}")

        # Train the model
        model = training_loop(model, loss_fn, optimizer, scheduler, train_loader, num_epochs=5)

        # Save final model
        save_final_model(model)

        print("\nTraining completed successfully!")

    except Exception as e:
        print(f"Error during training: {e}")
        import traceback
        traceback.print_exc()


In [17]:
# ===========================
# Execute Training
# ===========================
if __name__ == "__main__":
    train_laspa()

Resolving data files:   0%|          | 0/22 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/19 [00:00<?, ?it/s]

Loading dataset shards:   0%|          | 0/18 [00:00<?, ?it/s]

Loading sanskrit samples...


100%|████████████████████████████████████████████████████████| 26840/26840 [02:53<00:00, 154.31it/s]



Total samples: 26840
Total Speakers: 182, Total Languages: 1
  sanskrit: 182 speakers

Model Information:
Total parameters: 5,126,065
Trainable parameters: 5,126,065
Starting training...
Epoch 1, Batch 1
  Loss: 98.1440
  LMSE: 82.4190, LMAPC: 0.0689
  AAM: 15.6561, NLL: 0.0000
Epoch 1, Batch 6
  Loss: 103.7782
  LMSE: 91.9563, LMAPC: 0.0545
  AAM: 11.7674, NLL: 0.0000
Epoch 1, Batch 11
  Loss: 81.3836
  LMSE: 72.3105, LMAPC: 0.0989
  AAM: 8.9742, NLL: 0.0000

Epoch 1 Summary:
  Avg Loss: 101.1970
  Avg LMSE: 89.5106, Avg LMAPC: 0.0707
  Avg AAM: 11.6157, Avg NLL: 0.0000
  Loss weights - α: 1.0000, β: 1.0000
                γ: 1.0000, δ: 1.0000
  Learning rate: 0.001000

Saved checkpoint at epoch 1
Epoch 2, Batch 1
  Loss: 93.5003
  LMSE: 78.1308, LMAPC: 0.0165
  AAM: 15.3530, NLL: 0.0000
Epoch 2, Batch 6
  Loss: 100.7049
  LMSE: 86.9921, LMAPC: 0.0469
  AAM: 13.6659, NLL: 0.0000
Epoch 2, Batch 11
  Loss: 84.6595
  LMSE: 73.5715, LMAPC: 0.0419
  AAM: 11.0461, NLL: 0.0000

Epoch 2 Summ