In [1]:
import os
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

/kaggle/input/contraeend/pytorch/default/1/contraeend_best.pth
/kaggle/input/callhome/labels/labels_14.rttm
/kaggle/input/callhome/labels/labels_135.rttm
/kaggle/input/callhome/labels/labels_25.rttm
/kaggle/input/callhome/labels/labels_28.rttm
/kaggle/input/callhome/labels/labels_48.rttm
/kaggle/input/callhome/labels/labels_91.rttm
/kaggle/input/callhome/labels/labels_32.rttm
/kaggle/input/callhome/labels/labels_136.rttm
/kaggle/input/callhome/labels/labels_18.rttm
/kaggle/input/callhome/labels/labels_80.rttm
/kaggle/input/callhome/labels/labels_54.rttm
/kaggle/input/callhome/labels/labels_9.rttm
/kaggle/input/callhome/labels/labels_38.rttm
/kaggle/input/callhome/labels/labels_98.rttm
/kaggle/input/callhome/labels/labels_127.rttm
/kaggle/input/callhome/labels/labels_51.rttm
/kaggle/input/callhome/labels/labels_29.rttm
/kaggle/input/callhome/labels/labels_131.rttm
/kaggle/input/callhome/labels/labels_24.rttm
/kaggle/input/callhome/labels/labels_58.rttm
/kaggle/input/callhome/labels/labe

# 0.Full Block

In [2]:
"""
ContraEEND - phase 2 v2: Contrastive Pretraining (FIXED VERSION)
Pretrain encoder on callhome for speaker discrimination
"""
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
import numpy as np
from torch.utils.data import Dataset, DataLoader
from pathlib import Path
from typing import Tuple, Dict, List
import random
from tqdm import tqdm
import math
import time
def set_seed(seed=42):
    """Set random seed for reproducibility"""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

def setup_device():
    """Setup device with comprehensive CUDA checking for Kaggle"""
    if torch.cuda.is_available():
        device = torch.device("cuda:0")
        print(f"‚úÖ CUDA is available!")
        print(f"üöÄ Using GPU: {torch.cuda.get_device_name(0)}")
        print(
            f"üíæ GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB"
        )
        print(f"üî• CUDA Version: {torch.version.cuda}")

        # Set memory allocation strategy for Kaggle
        torch.cuda.empty_cache()
        if hasattr(torch.cuda, "set_memory_fraction"):
            torch.cuda.set_memory_fraction(0.8)  # Use 80% of GPU memory

    else:
        device = torch.device("cpu")
        print("‚ö†Ô∏è  CUDA not available, using CPU")
        print("üí° Consider enabling GPU in Kaggle: Settings -> Accelerator -> GPU")

    return device

setup_device()

class AudioProcessor:
    """Unified audio processing pipeline"""
    def __init__(self, 
                 sample_rate: int = 16000,
                 n_fft: int = 400,  # 25ms at 16kHz
                 hop_length: int = 160,  # 10ms at 16kHz
                 n_mels: int = 83,
                 win_length: int = 400):
        self.sample_rate = sample_rate
        self.mel_transform = torchaudio.transforms.MelSpectrogram(
            sample_rate=sample_rate,
            n_fft=n_fft,
            hop_length=hop_length,
            win_length=win_length,
            n_mels=n_mels,
            f_min=20,
            f_max=sample_rate // 2
        )
    
    def __call__(self, waveform: torch.Tensor) -> torch.Tensor:
        """
        Args:
            waveform: (channels, time) or (time,)
        Returns:
            log_mel: (n_mels, frames)
        """
        if waveform.dim() == 1:
            waveform = waveform.unsqueeze(0)
        
        # Ensure mono
        if waveform.shape[0] > 1:
            waveform = waveform.mean(dim=0, keepdim=True)
        
        # Compute mel spectrogram
        mel = self.mel_transform(waveform)
        
        # Log scaling with small epsilon for stability
        log_mel = torch.log(mel + 1e-6)
        
        return log_mel.squeeze(0)  # (n_mels, frames)


class SpecAugment(nn.Module):
    """SpecAugment for contrastive learning"""
    def __init__(self, freq_mask_param=27, time_mask_param=100, n_freq_masks=2, n_time_masks=2):
        super().__init__()
        self.freq_mask_param = freq_mask_param
        self.time_mask_param = time_mask_param
        self.n_freq_masks = n_freq_masks
        self.n_time_masks = n_time_masks
    
    def forward(self, mel: torch.Tensor) -> torch.Tensor:
        """
        Args:
            mel: (n_mels, time)
        Returns:
            augmented: (n_mels, time)
        """
        mel = mel.clone()
        n_mels, n_frames = mel.shape
        
        # Frequency masking
        for _ in range(self.n_freq_masks):
            f = random.randint(0, self.freq_mask_param)
            f0 = random.randint(0, max(0, n_mels - f))
            mel[f0:f0+f, :] = 0
        
        # Time masking
        for _ in range(self.n_time_masks):
            t = random.randint(0, min(self.time_mask_param, max(1, n_frames - 1)))
            t0 = random.randint(0, max(0, n_frames - t))
            mel[:, t0:t0+t] = 0
        
        return mel

class AudioAugmentation:
    """Advanced audio augmentation for speaker discrimination"""
    def __init__(self, sample_rate=16000):
        self.sample_rate = sample_rate
    
    def time_stretch(self, waveform: torch.Tensor, rate: float = None) -> torch.Tensor:
        """Time stretching (speed perturbation)"""
        if rate is None:
            rate = random.choice([0.9, 1.0, 1.1])
        
        if rate == 1.0:
            return waveform
        
        # Simple resampling-based time stretch
        original_length = waveform.shape[0]
        stretched_length = int(original_length / rate)
        
        if stretched_length > 0:
            stretched = F.interpolate(
                waveform.unsqueeze(0).unsqueeze(0),
                size=stretched_length,
                mode='linear',
                align_corners=False
            ).squeeze()
            
            # Crop or pad to original length
            if stretched.shape[0] > original_length:
                return stretched[:original_length]
            else:
                padding = original_length - stretched.shape[0]
                return F.pad(stretched, (0, padding))
        
        return waveform
    
    def pitch_shift(self, waveform: torch.Tensor, n_steps: int = None) -> torch.Tensor:
        """Pitch shifting"""
        if n_steps is None:
            n_steps = random.choice([-2, -1, 0, 1, 2])
        
        if n_steps == 0:
            return waveform
        
        # Approximate pitch shift via time stretch + resampling
        rate = 2 ** (n_steps / 12)
        shifted = self.time_stretch(waveform, rate)
        
        return shifted
    
    def add_noise(self, waveform: torch.Tensor, snr_db: float = None) -> torch.Tensor:
        """Add Gaussian noise"""
        if snr_db is None:
            snr_db = random.uniform(15, 30)  # SNR between 15-30 dB
        
        signal_power = waveform.pow(2).mean()
        snr_linear = 10 ** (snr_db / 10)
        noise_power = signal_power / snr_linear
        
        noise = torch.randn_like(waveform) * torch.sqrt(noise_power)
        
        return waveform + noise
    
    def __call__(self, waveform: torch.Tensor, prob: float = 0.5) -> torch.Tensor:
        """Apply random augmentation"""
        if random.random() < prob:
            aug_type = random.choice(['time_stretch', 'pitch_shift', 'noise'])
            
            if aug_type == 'time_stretch':
                return self.time_stretch(waveform)
            elif aug_type == 'pitch_shift':
                return self.pitch_shift(waveform)
            elif aug_type == 'noise':
                return self.add_noise(waveform)
        
        return waveform

def build_audio_rttm_mapping(audio_dir: str, rttm_dir: str) -> List[Tuple[Path, Path]]:
    """
    Build mapping between audio files and RTTM files by index
    audio_0.wav <-> labels_0.rttm
    """
    audio_dir = Path(audio_dir)
    rttm_dir = Path(rttm_dir)
    
    # Get all audio and rttm files
    audio_files = sorted(audio_dir.glob('audio_*.wav'))
    rttm_files = sorted(rttm_dir.glob('labels_*.rttm'))
    
    # Also try other extensions
    if len(audio_files) == 0:
        for ext in ['.flac', '.sph', '.mp3']:
            audio_files = sorted(audio_dir.glob(f'audio_*{ext}'))
            if len(audio_files) > 0:
                break
    
    print(f"\nFound {len(audio_files)} audio files")
    print(f"Found {len(rttm_files)} RTTM files")
    
    if len(audio_files) == 0:
        raise ValueError(f"No audio files found in {audio_dir}")
    if len(rttm_files) == 0:
        raise ValueError(f"No RTTM files found in {rttm_dir}")
    
    # Extract indices and match
    audio_map = {}
    for audio_file in audio_files:
        # Extract index from "audio_123.wav"
        try:
            idx = int(audio_file.stem.split('_')[1])
            audio_map[idx] = audio_file
        except (IndexError, ValueError):
            print(f"Warning: Cannot parse index from {audio_file.name}")
    
    rttm_map = {}
    for rttm_file in rttm_files:
        # Extract index from "labels_123.rttm"
        try:
            idx = int(rttm_file.stem.split('_')[1])
            rttm_map[idx] = rttm_file
        except (IndexError, ValueError):
            print(f"Warning: Cannot parse index from {rttm_file.name}")
    
    # Match by index
    pairs = []
    for idx in sorted(audio_map.keys()):
        if idx in rttm_map:
            pairs.append((audio_map[idx], rttm_map[idx]))
        else:
            print(f"Warning: No RTTM file for audio index {idx}")
    
    print(f"Matched {len(pairs)} audio-RTTM pairs")
    
    if len(pairs) == 0:
        raise ValueError("No matching audio-RTTM pairs found")
    
    # Show examples
    print(f"\nExample pairs:")
    for audio_file, rttm_file in pairs[:3]:
        print(f"  {audio_file.name} <-> {rttm_file.name}")
    
    return pairs


class callhomeContrastiveDataset(Dataset):
    """
    Callhome dataset for contrastive learning
    Parses RTTM files to extract speaker segments
    Handles indexed naming: audio_0.wav <-> labels_0.rttm
    
    OPTIMIZATION: Caches audio files in memory for faster loading
    """
    def __init__(self, 
             audio_dir: str,
             rttm_dir: str,
             audio_processor: AudioProcessor,
             segment_length: float = 3.0,
             sample_rate: int = 16000,
             min_segment_length: float = 2.0,
             apply_augment: bool = True,
             cache_audio: bool = True,
             use_audio_augment: bool = True):  # NEW
    
        self.audio_dir = Path(audio_dir)
        self.rttm_dir = Path(rttm_dir)
        self.audio_processor = audio_processor
        self.segment_length = segment_length
        self.segment_samples = int(segment_length * sample_rate)
        self.min_samples = int(min_segment_length * sample_rate)
        self.sample_rate = sample_rate
        self.apply_augment = apply_augment
        self.cache_audio = cache_audio
        self.audio_cache = {}
        
        # Spec augmentation
        self.spec_augment = SpecAugment(
            freq_mask_param=15,
            time_mask_param=50,
            n_freq_masks=1,
            n_time_masks=1
        ) if apply_augment else None
        
        # Audio augmentation (NEW)
        self.use_audio_augment = use_audio_augment
        self.audio_augment = AudioAugmentation(sample_rate) if use_audio_augment else None
        
        # Build audio-RTTM pairs by index
        print("\n" + "="*60)
        print("Building audio-RTTM mapping...")
        print("="*60)
        self.audio_rttm_pairs = build_audio_rttm_mapping(audio_dir, rttm_dir)
        
        # Parse RTTM files to build speaker segments
        self.speaker_segments = self._parse_rttm_files()
        self.speakers = list(self.speaker_segments.keys())
        
        # Create flat list for indexing
        self.samples = []
        for spk_id, segments in self.speaker_segments.items():
            for seg in segments:
                if seg['duration'] >= min_segment_length:
                    self.samples.append((spk_id, seg))
        
        # Preload all audio files into memory if caching enabled
        if self.cache_audio:
            print("\nüì¶ Caching audio files in memory...")
            self._cache_all_audio()
        
        print(f"\n‚úì Loaded {len(self.speakers)} speakers, {len(self.samples)} segments")
        print(f"  Segment length: {segment_length}s, Min length: {min_segment_length}s")
        print(f"  SpecAugment: {'enabled' if apply_augment else 'disabled'}")
        print(f"  Audio augmentation: {'enabled' if use_audio_augment else 'disabled'}")
        print(f"  Audio caching: {'enabled' if cache_audio else 'disabled'}")
    
    def _parse_rttm_files(self) -> Dict[str, List[Dict]]:
        """
        Parse RTTM files to extract speaker segments
        Each audio file can have multiple speakers
        """
        speaker_segments = {}
        
        print(f"\nParsing {len(self.audio_rttm_pairs)} audio-RTTM pairs...")
        
        segments_found = 0
        
        for audio_path, rttm_path in tqdm(self.audio_rttm_pairs, desc="Parsing RTTM"):
            # Get file identifier for this pair
            file_idx = audio_path.stem.split('_')[1]  # e.g., "0" from "audio_0.wav"
            
            with open(rttm_path, 'r') as f:
                for line in f:
                    line = line.strip()
                    if not line or line.startswith('#'):
                        continue
                    
                    parts = line.split()
                    if len(parts) < 8 or parts[0] != 'SPEAKER':
                        continue
                    
                    # RTTM format: SPEAKER <file_id> 1 <start> <duration> <NA> <NA> <speaker_id> <NA>
                    file_id_in_rttm = parts[1]  # This might be different from our filename
                    start_time = float(parts[3])
                    duration = float(parts[4])
                    speaker_id = parts[7]
                    
                    # Build unique speaker ID: use file index + speaker ID
                    # This ensures speakers from different files are treated as different
                    full_speaker_id = f"file{file_idx}_spk{speaker_id}"
                    
                    # Create segment info
                    segment = {
                        'file_idx': file_idx,
                        'audio_path': audio_path,
                        'start': start_time,
                        'duration': duration,
                        'speaker_id': speaker_id
                    }
                    
                    if full_speaker_id not in speaker_segments:
                        speaker_segments[full_speaker_id] = []
                    
                    speaker_segments[full_speaker_id].append(segment)
                    segments_found += 1
        
        print(f"Total segments found: {segments_found}")
        
        # Filter speakers with at least 2 segments
        original_speaker_count = len(speaker_segments)
        speaker_segments = {
            spk: segs for spk, segs in speaker_segments.items() 
            if len(segs) >= 2
        }
        
        filtered_count = original_speaker_count - len(speaker_segments)
        print(f"Speakers after filtering (‚â•2 segments): {len(speaker_segments)}")
        if filtered_count > 0:
            print(f"  Removed {filtered_count} speakers with <2 segments")
        
        # Print statistics
        if speaker_segments:
            seg_counts = [len(segs) for segs in speaker_segments.values()]
            durations = [seg['duration'] for segs in speaker_segments.values() for seg in segs]
            print(f"\nSegment statistics:")
            print(f"  Per speaker - Min: {min(seg_counts)}, Max: {max(seg_counts)}, Avg: {np.mean(seg_counts):.1f}")
            print(f"  Duration - Min: {min(durations):.1f}s, Max: {max(durations):.1f}s, Avg: {np.mean(durations):.1f}s")
        
        return speaker_segments
    
    def _cache_all_audio(self):
        """Preload all audio files into memory"""
        unique_audio_files = set()
        for _, segment in self.samples:
            unique_audio_files.add(segment['audio_path'])
        
        print(f"Loading {len(unique_audio_files)} unique audio files...")
        for audio_path in tqdm(unique_audio_files, desc="Caching audio"):
            try:
                waveform, sr = torchaudio.load(str(audio_path))
                
                # Resample if needed
                if sr != self.sample_rate:
                    resampler = torchaudio.transforms.Resample(sr, self.sample_rate)
                    waveform = resampler(waveform)
                
                # Convert to mono
                if waveform.shape[0] > 1:
                    waveform = waveform.mean(dim=0, keepdim=True)
                
                self.audio_cache[str(audio_path)] = waveform.squeeze(0)
            except Exception as e:
                print(f"\nError caching {audio_path}: {e}")
        
        print(f"‚úì Cached {len(self.audio_cache)} audio files")
    
    def _load_audio_segment(self, segment_info: Dict) -> torch.Tensor:
        """Load audio segment from file or cache based on RTTM timing"""
        audio_path = segment_info['audio_path']
        start_time = segment_info['start']
        duration = segment_info['duration']
        
        try:
            # Load from cache or file
            if self.cache_audio and str(audio_path) in self.audio_cache:
                waveform = self.audio_cache[str(audio_path)]
            else:
                waveform, sr = torchaudio.load(str(audio_path))
                
                # Resample if needed
                if sr != self.sample_rate:
                    resampler = torchaudio.transforms.Resample(sr, self.sample_rate)
                    waveform = resampler(waveform)
                
                # Convert to mono
                if waveform.shape[0] > 1:
                    waveform = waveform.mean(dim=0, keepdim=True)
                
                waveform = waveform.squeeze(0)
            
            # Extract segment based on RTTM timing
            start_sample = int(start_time * self.sample_rate)
            duration_samples = int(duration * self.sample_rate)
            end_sample = start_sample + duration_samples
            
            # Clip to valid range
            start_sample = max(0, start_sample)
            end_sample = min(len(waveform), end_sample)
            
            segment = waveform[start_sample:end_sample]
            
            # Crop or pad to target length
            segment_len = len(segment)
            
            if segment_len < self.segment_samples:
                # Pad if too short
                padding = self.segment_samples - segment_len
                segment = F.pad(segment.unsqueeze(0), (0, padding)).squeeze(0)
            elif segment_len > self.segment_samples:
                # Random crop if too long
                max_start = segment_len - self.segment_samples
                crop_start = random.randint(0, max_start)
                segment = segment[crop_start:crop_start + self.segment_samples]
            
            return segment
        
        except Exception as e:
            print(f"\nError loading {audio_path}: {e}")
            # Return silence if loading fails
            return torch.zeros(self.segment_samples)
    
    def __len__(self) -> int:
        return len(self.samples)
    
    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, int]:
        """
        Returns:
            anchor: (n_mels, frames)
            positive: (n_mels, frames) - different segment from same speaker
            speaker_id: int
        """
        speaker_id, anchor_segment = self.samples[idx]
        
        # Load anchor
        anchor_wav = self._load_audio_segment(anchor_segment)
        
        # Apply audio augmentation to anchor (NEW)
        if self.use_audio_augment and self.audio_augment is not None:
            anchor_wav = self.audio_augment(anchor_wav, prob=0.5)
        
        anchor_mel = self.audio_processor(anchor_wav)
        
        # Apply spec augmentation
        if self.apply_augment and self.spec_augment is not None:
            anchor_mel = self.spec_augment(anchor_mel)
        
        # Load positive (different segment from same speaker)
        positive_segment = random.choice(self.speaker_segments[speaker_id])
        
        # Ensure it's different from anchor if possible
        if len(self.speaker_segments[speaker_id]) > 1:
            max_attempts = 10
            for _ in range(max_attempts):
                positive_segment = random.choice(self.speaker_segments[speaker_id])
                if positive_segment['start'] != anchor_segment['start']:
                    break
        
        positive_wav = self._load_audio_segment(positive_segment)
        
        # Apply DIFFERENT audio augmentation to positive (NEW)
        if self.use_audio_augment and self.audio_augment is not None:
            positive_wav = self.audio_augment(positive_wav, prob=0.5)
        
        positive_mel = self.audio_processor(positive_wav)
        
        # Apply different spec augmentation to positive
        if self.apply_augment and self.spec_augment is not None:
            positive_mel = self.spec_augment(positive_mel)
        
        # Convert speaker_id to numeric
        speaker_idx = self.speakers.index(speaker_id)
        
        return anchor_mel, positive_mel, speaker_idx


class ConformerBlock(nn.Module):
    """Single Conformer block with multi-head attention and convolution"""
    def __init__(self, d_model: int, n_heads: int, conv_kernel: int = 31, dropout: float = 0.1):
        super().__init__()
        
        # Feed-forward module 1
        self.ff1 = nn.Sequential(
            nn.LayerNorm(d_model),
            nn.Linear(d_model, d_model * 4),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Linear(d_model * 4, d_model),
            nn.Dropout(dropout)
        )
        
        # Multi-head self-attention
        self.norm_attn = nn.LayerNorm(d_model)
        self.attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout, batch_first=True)
        self.dropout_attn = nn.Dropout(dropout)
        
        # Convolution module
        self.norm_conv = nn.LayerNorm(d_model)
        self.conv = nn.Sequential(
            nn.Conv1d(d_model, d_model * 2, 1),
            nn.GLU(dim=1),
            nn.Conv1d(d_model, d_model, conv_kernel, padding=conv_kernel//2, groups=d_model),
            nn.BatchNorm1d(d_model),
            nn.SiLU(),
            nn.Conv1d(d_model, d_model, 1),
            nn.Dropout(dropout)
        )
        
        # Feed-forward module 2
        self.ff2 = nn.Sequential(
            nn.LayerNorm(d_model),
            nn.Linear(d_model, d_model * 4),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Linear(d_model * 4, d_model),
            nn.Dropout(dropout)
        )
        
        self.norm_out = nn.LayerNorm(d_model)
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (batch, time, d_model)
        Returns:
            output: (batch, time, d_model)
        """
        # FF1
        x = x + 0.5 * self.ff1(x)
        
        # Attention
        x_norm = self.norm_attn(x)
        attn_out, _ = self.attn(x_norm, x_norm, x_norm)
        x = x + self.dropout_attn(attn_out)
        
        # Convolution
        x_norm = self.norm_conv(x)
        x_conv = x_norm.transpose(1, 2)  # (batch, d_model, time)
        x_conv = self.conv(x_conv)
        x = x + x_conv.transpose(1, 2)
        
        # FF2
        x = x + 0.5 * self.ff2(x)
        
        return self.norm_out(x)


class ConformerEncoder(nn.Module):
    """Conformer encoder for audio processing"""
    def __init__(self, 
                 input_dim: int = 83,
                 d_model: int = 256,
                 n_layers: int = 8,
                 n_heads: int = 4,
                 conv_kernel: int = 31,
                 dropout: float = 0.1,
                 subsampling_factor: int = 4):
        super().__init__()
        
        # Subsampling layer (reduce frame rate by 4x)
        self.subsampling = nn.Sequential(
            nn.Conv1d(input_dim, d_model, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv1d(d_model, d_model, kernel_size=3, stride=2, padding=1),
            nn.ReLU()
        )
        
        # Positional encoding
        self.pos_encoding = PositionalEncoding(d_model, dropout)
        
        # Conformer blocks
        self.blocks = nn.ModuleList([
            ConformerBlock(d_model, n_heads, conv_kernel, dropout)
            for _ in range(n_layers)
        ])
        
        self.d_model = d_model
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (batch, n_mels, time)
        Returns:
            output: (batch, time//4, d_model)
        """
        # Subsampling
        x = self.subsampling(x)  # (batch, d_model, time//4)
        x = x.transpose(1, 2)  # (batch, time//4, d_model)
        
        # Positional encoding
        x = self.pos_encoding(x)
        
        # Conformer blocks
        for block in self.blocks:
            x = block(x)
        
        return x


class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding"""
    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 10000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(1, max_len, d_model)
        pe[0, :, 0::2] = torch.sin(position * div_term)
        pe[0, :, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.pe[:, :x.size(1)]
        return self.dropout(x)


class ProjectionHead(nn.Module):
    """Projection head for contrastive learning"""
    def __init__(self, input_dim: int, hidden_dim: int = 256, output_dim: int = 128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim)
        )
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


class ContraEENDEncoder(nn.Module):
    """Full encoder with projection head"""
    def __init__(self, 
                 input_dim: int = 83,
                 d_model: int = 256,
                 n_layers: int = 8,
                 n_heads: int = 4,
                 projection_dim: int = 128):
        super().__init__()
        
        self.encoder = ConformerEncoder(
            input_dim=input_dim,
            d_model=d_model,
            n_layers=n_layers,
            n_heads=n_heads
        )
        
        # Temporal pooling for utterance-level representation
        self.temporal_pool = nn.AdaptiveAvgPool1d(1)
        
        # Projection head
        self.projection = ProjectionHead(d_model, d_model, projection_dim)
    
    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            x: (batch, n_mels, time)
        Returns:
            embeddings: (batch, d_model) - frame-level pooled
            projections: (batch, projection_dim) - for contrastive loss
        """
        # Encode
        encoded = self.encoder(x)  # (batch, time//4, d_model)
        
        # Temporal pooling
        pooled = self.temporal_pool(encoded.transpose(1, 2)).squeeze(-1)  # (batch, d_model)
        
        # Project
        projected = self.projection(pooled)  # (batch, projection_dim)
        
        # L2 normalize projections
        projected = F.normalize(projected, p=2, dim=1)
        
        return pooled, projected


class SupConLoss(nn.Module):
    """Supervised Contrastive Loss with hard negative mining"""
    def __init__(self, temperature: float = 0.2, base_temperature: float = 0.2):
        super().__init__()
        self.temperature = temperature
        self.base_temperature = base_temperature
    
    def forward(self, features: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        device = features.device
        batch_size = features.shape[0]
        
        labels = labels.contiguous().view(-1, 1)
        mask = torch.eq(labels, labels.T).float().to(device)
        
        # Compute similarity matrix
        anchor_dot_contrast = torch.div(
            torch.matmul(features, features.T),
            self.temperature
        )
        
        # For numerical stability
        logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
        logits = anchor_dot_contrast - logits_max.detach()
        
        # Remove self-contrast
        logits_mask = torch.scatter(
            torch.ones_like(mask),
            1,
            torch.arange(batch_size).view(-1, 1).to(device),
            0
        )
        mask = mask * logits_mask
        
        # HARD NEGATIVE MINING: Weight harder negatives more
        exp_logits = torch.exp(logits) * logits_mask
        
        # Hard negative mask (only keep negatives, exclude self and positives)
        hard_negative_mask = (1 - mask) * logits_mask
        
        # Weighted by difficulty (higher similarity = harder negative)
        hard_weights = torch.exp(logits) * hard_negative_mask
        hard_weights = hard_weights / (hard_weights.sum(1, keepdim=True) + 1e-8)
        
        # Reweight denominator with hard negatives
        weighted_exp_logits = exp_logits * (1 + hard_weights)
        
        log_prob = logits - torch.log(weighted_exp_logits.sum(1, keepdim=True))
        
        # Compute mean over positives
        mean_log_prob_pos = (mask * log_prob).sum(1) / (mask.sum(1) + 1e-8)
        
        loss = - (self.temperature / self.base_temperature) * mean_log_prob_pos
        loss = loss.mean()
        
        return loss

import csv
from datetime import datetime

class MetricsLogger:
    """CSV logger for tracking training metrics"""
    def __init__(self, log_dir: str = './logs', experiment_name: str = 'contraeend'):
        self.log_dir = Path(log_dir)
        self.log_dir.mkdir(parents=True, exist_ok=True)
        
        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        self.csv_path = self.log_dir / f'{experiment_name}_{timestamp}.csv'
        
        self.columns = [
            'timestamp', 'epoch', 'phase', 'learning_rate',
            'train_loss', 'val_loss',
            'pos_sim_mean', 'neg_sim_mean', 'separation',
            'accuracy', 'adjusted_rand_score', 'der', 'transition_rate',
            'avg_segment_duration', 'num_segments', 'temporal_stability',
            'per_speaker_accuracy_mean', 'per_speaker_accuracy_std',
            'num_speakers_detected', 'num_speakers_true',
            'has_temporal_modeling', 'oracle_speakers',
            'filename', 'duration',
            'batch_size', 'd_model', 'n_layers', 'temperature',
        ]
        
        with open(self.csv_path, 'w', newline='') as f:
            writer = csv.DictWriter(f, fieldnames=self.columns)
            writer.writeheader()
        
        print(f"üìä Metrics logger initialized: {self.csv_path}")
    
    def log_epoch(self, epoch, phase, train_loss, val_loss, learning_rate,
                  embedding_metrics=None, config=None):
        """Log metrics for one epoch"""
        row = {
            'timestamp': datetime.now().isoformat(),
            'epoch': epoch,
            'phase': phase,
            'learning_rate': learning_rate,
            'train_loss': train_loss,
            'val_loss': val_loss,
        }
        
        if embedding_metrics:
            row.update({
                'pos_sim_mean': embedding_metrics.get('pos_sim_mean', ''),
                'neg_sim_mean': embedding_metrics.get('neg_sim_mean', ''),
                'separation': embedding_metrics.get('separation', ''),
            })
        
        if config:
            row.update({
                'batch_size': config.get('batch_size', ''),
                'd_model': config.get('d_model', ''),
                'n_layers': config.get('n_layers', ''),
                'temperature': config.get('temperature', ''),
            })
        
        with open(self.csv_path, 'a', newline='') as f:
            writer = csv.DictWriter(f, fieldnames=self.columns)
            writer.writerow(row)


class ContrastiveTrainer:
    """Trainer for phase 2 v2: Contrastive Pretraining"""
    def __init__(self,
                 model: nn.Module,
                 train_loader: DataLoader,
                 val_loader: DataLoader,
                 num_epochs: int,
                 device: str = 'cuda',
                 learning_rate: float = 1e-3,
                 weight_decay: float = 1e-4,
                 temperature: float = 0.2,
                 accumulation_steps: int = 4, 
                 log_dir: str = './logs',
                 config: dict = None):
        
        self.model = model.to(device)
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.device = device
        
        # Loss
        self.criterion = SupConLoss(temperature=temperature)
        
        # Optimizer
        self.optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=learning_rate,
            weight_decay=weight_decay,
            betas=(0.9, 0.98)
        )
        
        # Warmup + Cosine scheduler
        self.warmup_epochs = 5
        self.total_steps = len(train_loader) * num_epochs
        self.warmup_steps = len(train_loader) * self.warmup_epochs
        self.current_step = 0
        
        def lr_lambda(current_step):
            if current_step < self.warmup_steps:
                # Linear warmup
                return float(current_step) / float(max(1, self.warmup_steps))
            # Cosine decay after warmup
            progress = float(current_step - self.warmup_steps) / float(max(1, self.total_steps - self.warmup_steps))
            return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))
        
        from torch.optim.lr_scheduler import LambdaLR
        self.scheduler = LambdaLR(self.optimizer, lr_lambda)
        
        # Mixed precision
        self.scaler = torch.cuda.amp.GradScaler() if device == 'cuda' else None
        self.use_amp = device == 'cuda'
        
        # Early stopping
        self.best_val_loss = float('inf')
        self.patience = 10
        self.patience_counter = 0
        self.min_delta = 1e-4
        self.accumulation_steps = accumulation_steps
        self.logger = MetricsLogger(log_dir=log_dir, experiment_name='contraeend')
        self.config = config
        print(f"Effective batch size: {train_loader.batch_size * 2 * accumulation_steps}")
    
    
    def train_epoch(self, epoch: int) -> float:
        """Train one epoch with gradient accumulation"""
        self.model.train()
        # Check GPU usage on first batch
        if epoch == 1:
            print("\nüîç Checking GPU usage on first batch...")
        total_loss = 0
        accumulated_loss = 0

        # Timing
        data_time = 0
        forward_time = 0
        backward_time = 0
        
        pbar = tqdm(self.train_loader, desc=f'Epoch {epoch} [Train]')
        for batch_idx, (anchor, positive, labels) in enumerate(pbar):
            t0 = time.time()
            
            anchor = anchor.to(self.device)
            positive = positive.to(self.device)
            labels = labels.to(self.device)
            combined = torch.cat([anchor, positive], dim=0)
            combined_labels = torch.cat([labels, labels], dim=0)
            
            data_time += time.time() - t0
            t1 = time.time()
            
            # Forward
            with torch.cuda.amp.autocast(enabled=self.use_amp):
                _, projections = self.model(combined)
                loss = self.criterion(projections, combined_labels)
                loss = loss / self.accumulation_steps
            
            forward_time += time.time() - t1
            t2 = time.time()
                
            # Backward
            if self.use_amp:
                self.scaler.scale(loss).backward()
            else:
                loss.backward()
            
            accumulated_loss += loss.item()
            backward_time += time.time() - t2
            # Show timing on first epoch
            if epoch == 1 and batch_idx == 10:
                print(f"\n‚è±Ô∏è Timing breakdown (first 10 batches):")
                print(f"  Data loading: {data_time/10:.2f}s per batch")
                print(f"  Forward pass: {forward_time/10:.2f}s per batch")
                print(f"  Backward pass: {backward_time/10:.2f}s per batch")
            
            # Update weights every accumulation_steps
            if (batch_idx + 1) % self.accumulation_steps == 0:
                if self.use_amp:
                    self.scaler.unscale_(self.optimizer)
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=5.0)
                    self.scaler.step(self.optimizer)
                    self.scaler.update()
                else:
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=5.0)
                    self.optimizer.step()
                
                self.optimizer.zero_grad()
                
                # Step scheduler
                self.current_step += 1
                self.scheduler.step()
                
                total_loss += accumulated_loss
                current_lr = self.optimizer.param_groups[0]['lr']
                pbar.set_postfix({
                    'loss': f'{accumulated_loss:.4f}',
                    'lr': f'{current_lr:.6f}'
                })
                accumulated_loss = 0
        
        avg_loss = total_loss / (len(self.train_loader) // self.accumulation_steps)
        return avg_loss
    
    def validate(self, epoch: int) -> float:
        """Validate"""
        self.model.eval()
        total_loss = 0
        
        with torch.no_grad():
            pbar = tqdm(self.val_loader, desc=f'Epoch {epoch} [Val]')
            for anchor, positive, labels in pbar:
                anchor = anchor.to(self.device)
                positive = positive.to(self.device)
                labels = labels.to(self.device)
                
                combined = torch.cat([anchor, positive], dim=0)
                combined_labels = torch.cat([labels, labels], dim=0)
                
                with torch.cuda.amp.autocast(enabled=self.use_amp):
                    _, projections = self.model(combined)
                    loss = self.criterion(projections, combined_labels)
                
                total_loss += loss.item()
                pbar.set_postfix({'loss': f'{loss.item():.4f}'})
        
        avg_loss = total_loss / len(self.val_loader)
        return avg_loss
    
    def compute_embedding_metrics(self, loader: DataLoader) -> Dict[str, float]:
        """Compute cosine similarity metrics for validation"""
        self.model.eval()
        embeddings_list = []
        labels_list = []
        
        with torch.no_grad():
            for batch_idx, (anchor, positive, labels) in enumerate(loader):
                anchor = anchor.to(self.device)
                with torch.cuda.amp.autocast(enabled=self.use_amp):
                    _, anchor_proj = self.model(anchor)
                embeddings_list.append(anchor_proj.cpu())
                labels_list.append(labels)
                
                if batch_idx >= 20:  # Limit samples for speed
                    break
        
        if len(embeddings_list) == 0:
            return {'pos_sim_mean': 0.0, 'neg_sim_mean': 0.0, 'separation': 0.0}
        
        embeddings = torch.cat(embeddings_list, dim=0)  # (N, projection_dim)
        labels = torch.cat(labels_list, dim=0)  # (N,)
        
        # Compute cosine similarity matrix
        sim_matrix = torch.matmul(embeddings, embeddings.T)  # Already normalized
        
        # Same speaker pairs (positive)
        same_speaker_mask = (labels.unsqueeze(0) == labels.unsqueeze(1))
        same_speaker_mask.fill_diagonal_(False)  # Exclude self
        same_speaker_sims = sim_matrix[same_speaker_mask]
        
        # Different speaker pairs (negative)
        diff_speaker_mask = ~same_speaker_mask
        diff_speaker_mask.fill_diagonal_(False)
        diff_speaker_sims = sim_matrix[diff_speaker_mask]
        
        metrics = {
            'pos_sim_mean': same_speaker_sims.mean().item() if len(same_speaker_sims) > 0 else 0.0,
            'neg_sim_mean': diff_speaker_sims.mean().item() if len(diff_speaker_sims) > 0 else 0.0,
            'separation': (same_speaker_sims.mean() - diff_speaker_sims.mean()).item() if len(same_speaker_sims) > 0 else 0.0
        }
        
        return metrics
    
    def save_checkpoint(self, epoch: int, loss: float, filepath: str):
        """Save training checkpoint"""
        torch.save({
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
            'loss': loss,
            'best_val_loss': self.best_val_loss,
            'current_step': self.current_step,
            'patience_counter': self.patience_counter
        }, filepath)
        print(f"Checkpoint saved: {filepath}")
    
    def load_checkpoint(self, filepath: str) -> int:
        """Resume from checkpoint"""
        if not os.path.exists(filepath):
            print(f"No checkpoint found at {filepath}")
            return 0
        
        checkpoint = torch.load(filepath, map_location=self.device)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        self.best_val_loss = checkpoint.get('best_val_loss', float('inf'))
        self.current_step = checkpoint.get('current_step', 0)
        
        # RESET patience for enhancement training
        self.patience_counter = 0  # Force reset
        
        epoch = checkpoint['epoch']
        print(f"‚úì Resumed from epoch {epoch}, best val loss: {self.best_val_loss:.4f}")
        print(f"‚ö†Ô∏è  Patience counter reset to 0 for enhancement training")
        return epoch
        
    def train(self, num_epochs: int, checkpoint_dir: str = './checkpoints', start_epoch: int = 0):
        """Full training loop with early stopping"""
        Path(checkpoint_dir).mkdir(parents=True, exist_ok=True)
        
        for epoch in range(start_epoch + 1, num_epochs + 1):
            print(f"\n{'='*60}")
            print(f"Epoch {epoch}/{num_epochs}")
            print(f"{'='*60}")
            
            # Train
            train_loss = self.train_epoch(epoch)
            print(f"Train Loss: {train_loss:.4f}")
            
            # Validate
            val_loss = self.validate(epoch)
            print(f"Val Loss: {val_loss:.4f}")
            
            # Compute embedding metrics
            metrics = self.compute_embedding_metrics(self.val_loader)
            print(f"Pos Sim: {metrics['pos_sim_mean']:.4f} | Neg Sim: {metrics['neg_sim_mean']:.4f} | Separation: {metrics['separation']:.4f}")
            
            current_lr = self.optimizer.param_groups[0]['lr']
            print(f"Learning Rate: {current_lr:.6f}")

            self.logger.log_epoch(
                epoch=epoch, phase='pretrain', train_loss=train_loss,
                val_loss=val_loss, learning_rate=current_lr,
                embedding_metrics=metrics, config=self.config
            )
            
            # Save checkpoint every epoch
            checkpoint_path = Path(checkpoint_dir) / f'contraeend_epoch_{epoch}.pth'
            self.save_checkpoint(epoch, val_loss, str(checkpoint_path))
            
            # Early stopping check
            if val_loss < self.best_val_loss - self.min_delta:
                self.best_val_loss = val_loss
                self.patience_counter = 0
                best_path = Path(checkpoint_dir) / 'contraeend_best.pth'
                self.save_checkpoint(epoch, val_loss, str(best_path))
                print(f"‚úì New best model saved! Val Loss: {val_loss:.4f}")
            else:
                self.patience_counter += 1
                print(f"Patience: {self.patience_counter}/{self.patience}")
                
                if self.patience_counter >= self.patience:
                    print(f"\n‚ö†Ô∏è Early stopping triggered after {epoch} epochs")
                    break
        
        print(f"\nTraining completed! Best validation loss: {self.best_val_loss:.4f}")

def main():
    """Main training function for phase 2 v2"""
    
    # Set random seed for reproducibility
    set_seed(42)

    # Configuration
    config = {
        'audio_dir': '/kaggle/input/callhome/audio',
        'rttm_dir': '/kaggle/input/callhome/labels',
        'batch_size': 128,  # INCREASED from 64
        'num_epochs': 200,   # Reduced since we have better augmentation
        'learning_rate': 5e-5,
        'weight_decay': 1e-4,
        'temperature': 0.15,  # LOWERED from 0.2 (more strict)
        'segment_length': 3.0,
        'd_model': 128,
        'n_layers': 6,
        'n_heads': 4,
        'projection_dim': 64,
        'num_workers': 4,
        'device': 'cuda' if torch.cuda.is_available() else 'cpu',
        'resume_from': '/kaggle/input/contraeend/pytorch/default/1/contraeend_best.pth',  # Resume from best model
        'cache_audio': True,
        'use_audio_augment': True,  # NEW: Enable audio augmentation
    }
    
    print("="*60)
    print("ContraEEND - phase 2 v2: Contrastive Pretraining (FIXED)")
    print("="*60)
    print(f"Device: {config['device']}")
    print(f"Audio dir: {config['audio_dir']}")
    print(f"RTTM dir: {config['rttm_dir']}")
    print(f"Batch Size: {config['batch_size']}")
    print(f"Temperature: {config['temperature']}")
    print(f"Model: Conformer ({config['n_layers']} layers, {config['d_model']} dim)")
    print("="*60)
    
    # Audio processor
    audio_processor = AudioProcessor()
    
    # Dataset with RTTM parsing
    print("\nLoading datasets...")
    full_dataset = callhomeContrastiveDataset(
        audio_dir=config['audio_dir'],
        rttm_dir=config['rttm_dir'],
        audio_processor=audio_processor,
        segment_length=config['segment_length'],
        apply_augment=True,
        cache_audio=config['cache_audio']  # NEW: Pass cache parameter
    )
        
    # Check if dataset is empty
    if len(full_dataset) == 0:
        print("\n‚ùå ERROR: No valid samples found!")
        print("\nPossible issues:")
        print("1. Audio files don't match RTTM file IDs")
        print("2. All segments are too short (< 2.0s)")
        print("3. No speakers have ‚â•2 segments")
        return
    
    # Dataset diagnostics
    print("\n" + "="*60)
    print("üîç DATASET DIAGNOSTICS")
    print("="*60)
    print(f"Total samples: {len(full_dataset)}")
    print(f"Total speakers: {len(full_dataset.speakers)}")
    
    if len(full_dataset.speakers) > 0:
        speaker_counts = {spk: len(segs) for spk, segs in full_dataset.speaker_segments.items()}
        segments_per_speaker = list(speaker_counts.values())
        
        print(f"Avg segments per speaker: {np.mean(segments_per_speaker):.1f}")
        print(f"Min segments per speaker: {min(segments_per_speaker)}")
        print(f"Max segments per speaker: {max(segments_per_speaker)}")
        
        print(f"\nSample speakers: {full_dataset.speakers[:5]}")
    
    # Check if dataset is sufficient
    if len(full_dataset) < 200:
        print("\n‚ö†Ô∏è  WARNING: Dataset is small!")
        print(f"   Current: {len(full_dataset)} samples")
        print(f"   This may limit model performance")
    
    if len(full_dataset.speakers) < 20:
        print("\n‚ö†Ô∏è  WARNING: Few speakers!")
        print(f"   Current: {len(full_dataset.speakers)} speakers")
        print(f"   Contrastive learning works best with 100+ speakers")
    
    print("="*60 + "\n")
    
    # Split for validation (90/10)
    train_size = int(0.9 * len(full_dataset))
    val_size = len(full_dataset) - train_size
    train_dataset, val_dataset = torch.utils.data.random_split(
        full_dataset, [train_size, val_size],
        generator=torch.Generator().manual_seed(42)
    )
    
    print(f"Train samples: {len(train_dataset)}")
    print(f"Val samples: {len(val_dataset)}")
    
    # Adjust batch size and accumulation if needed
    if len(train_dataset) < config['batch_size'] * 4:
        print("\n‚ö†Ô∏è  Small dataset: adjusting batch size and disabling accumulation")
        config['batch_size'] = min(16, len(train_dataset) // 2)
        accumulation_steps = 1
    else:
        accumulation_steps = 4
    
    print(f"Batch size: {config['batch_size']}")
    print(f"Accumulation steps: {accumulation_steps}")
    print(f"Effective batch size: {config['batch_size'] * accumulation_steps * 2}")
    
    # Data loaders
    train_loader = DataLoader(
        train_dataset,
        batch_size=config['batch_size'],
        shuffle=True,
        num_workers=config['num_workers'],
        pin_memory=True if config['device'] == 'cuda' else False,
        drop_last=True if len(train_dataset) > config['batch_size'] else False,
        persistent_workers=True if config['num_workers'] > 0 else False  # NEW
    )
    
    val_loader = DataLoader(
        val_dataset,
        batch_size=config['batch_size'],
        shuffle=False,
        num_workers=config['num_workers'],
        pin_memory=True if config['device'] == 'cuda' else False,
        persistent_workers=True if config['num_workers'] > 0 else False  # NEW
    )
        
    print(f"Train batches: {len(train_loader)}")
    print(f"Val batches: {len(val_loader)}")
    
    # Model
    print("\nInitializing model...")
    model = ContraEENDEncoder(
        input_dim=83,
        d_model=config['d_model'],
        n_layers=config['n_layers'],
        n_heads=config['n_heads'],
        projection_dim=config['projection_dim']
    )
    
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Total parameters: {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,}")
    
    # Trainer
    print("\nInitializing trainer...")
    trainer = ContrastiveTrainer(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        num_epochs=config['num_epochs'],
        device=config['device'],
        learning_rate=config['learning_rate'],
        weight_decay=config['weight_decay'],
        temperature=config['temperature'],
        accumulation_steps=8 # INCREASED from 4: Effective batch = 128*8*2 = 2048
    )
    
    # Resume from checkpoint if specified
    start_epoch = 0
    if config['resume_from'] is not None and os.path.exists(config['resume_from']):
        print(f"\nResuming from checkpoint: {config['resume_from']}")
        start_epoch = trainer.load_checkpoint(config['resume_from'])
        # FORCE RESET PATIENCE
        trainer.patience_counter = 0
        trainer.patience = 15  # more patience
        print(f"‚ö†Ô∏è  Patience reset to 0/15 for enhancement training")
    
    
    # Train
    print("\nStarting training...")
    print("Features enabled:")
    print("="*60)
    
    trainer.train(
        num_epochs=config['num_epochs'],
        start_epoch=start_epoch
    )
    
    print("\n" + "="*60)
    print("== Training Complete! ==")
    print(f"Best Validation Loss: {trainer.best_val_loss:.4f}")
    print("="*60)


if __name__ == '__main__':
    main()

In [3]:
"""
ContraEEND - phase 2 v2: Contrastive Pretraining (FIXED VERSION)
Pretrain encoder on callhome for speaker discrimination
"""
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
import numpy as np
from torch.utils.data import Dataset, DataLoader
from pathlib import Path
from typing import Tuple, Dict, List
import random
from tqdm import tqdm
import math
import time
def set_seed(seed=42):
    """Set random seed for reproducibility"""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

def setup_device():
    """Setup device with comprehensive CUDA checking for Kaggle"""
    if torch.cuda.is_available():
        device = torch.device("cuda:0")
        print(f"‚úÖ CUDA is available!")
        print(f"üöÄ Using GPU: {torch.cuda.get_device_name(0)}")
        print(
            f"üíæ GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB"
        )
        print(f"üî• CUDA Version: {torch.version.cuda}")

        # Set memory allocation strategy for Kaggle
        torch.cuda.empty_cache()
        if hasattr(torch.cuda, "set_memory_fraction"):
            torch.cuda.set_memory_fraction(0.8)  # Use 80% of GPU memory

    else:
        device = torch.device("cpu")
        print("‚ö†Ô∏è  CUDA not available, using CPU")
        print("üí° Consider enabling GPU in Kaggle: Settings -> Accelerator -> GPU")

    return device

setup_device()

‚úÖ CUDA is available!
üöÄ Using GPU: Tesla P100-PCIE-16GB
üíæ GPU Memory: 17.1 GB
üî• CUDA Version: 12.4


device(type='cuda', index=0)

# 1. AUDIO PROCESSING

In [4]:
class AudioProcessor:
    """Unified audio processing pipeline"""
    def __init__(self, 
                 sample_rate: int = 16000,
                 n_fft: int = 400,  # 25ms at 16kHz
                 hop_length: int = 160,  # 10ms at 16kHz
                 n_mels: int = 83,
                 win_length: int = 400):
        self.sample_rate = sample_rate
        self.mel_transform = torchaudio.transforms.MelSpectrogram(
            sample_rate=sample_rate,
            n_fft=n_fft,
            hop_length=hop_length,
            win_length=win_length,
            n_mels=n_mels,
            f_min=20,
            f_max=sample_rate // 2
        )
    
    def __call__(self, waveform: torch.Tensor) -> torch.Tensor:
        """
        Args:
            waveform: (channels, time) or (time,)
        Returns:
            log_mel: (n_mels, frames)
        """
        if waveform.dim() == 1:
            waveform = waveform.unsqueeze(0)
        
        # Ensure mono
        if waveform.shape[0] > 1:
            waveform = waveform.mean(dim=0, keepdim=True)
        
        # Compute mel spectrogram
        mel = self.mel_transform(waveform)
        
        # Log scaling with small epsilon for stability
        log_mel = torch.log(mel + 1e-6)
        
        return log_mel.squeeze(0)  # (n_mels, frames)


class SpecAugment(nn.Module):
    """SpecAugment for contrastive learning"""
    def __init__(self, freq_mask_param=27, time_mask_param=100, n_freq_masks=2, n_time_masks=2):
        super().__init__()
        self.freq_mask_param = freq_mask_param
        self.time_mask_param = time_mask_param
        self.n_freq_masks = n_freq_masks
        self.n_time_masks = n_time_masks
    
    def forward(self, mel: torch.Tensor) -> torch.Tensor:
        """
        Args:
            mel: (n_mels, time)
        Returns:
            augmented: (n_mels, time)
        """
        mel = mel.clone()
        n_mels, n_frames = mel.shape
        
        # Frequency masking
        for _ in range(self.n_freq_masks):
            f = random.randint(0, self.freq_mask_param)
            f0 = random.randint(0, max(0, n_mels - f))
            mel[f0:f0+f, :] = 0
        
        # Time masking
        for _ in range(self.n_time_masks):
            t = random.randint(0, min(self.time_mask_param, max(1, n_frames - 1)))
            t0 = random.randint(0, max(0, n_frames - t))
            mel[:, t0:t0+t] = 0
        
        return mel

class AudioAugmentation:
    """Advanced audio augmentation for speaker discrimination"""
    def __init__(self, sample_rate=16000):
        self.sample_rate = sample_rate
    
    def time_stretch(self, waveform: torch.Tensor, rate: float = None) -> torch.Tensor:
        """Time stretching (speed perturbation)"""
        if rate is None:
            rate = random.choice([0.9, 1.0, 1.1])
        
        if rate == 1.0:
            return waveform
        
        # Simple resampling-based time stretch
        original_length = waveform.shape[0]
        stretched_length = int(original_length / rate)
        
        if stretched_length > 0:
            stretched = F.interpolate(
                waveform.unsqueeze(0).unsqueeze(0),
                size=stretched_length,
                mode='linear',
                align_corners=False
            ).squeeze()
            
            # Crop or pad to original length
            if stretched.shape[0] > original_length:
                return stretched[:original_length]
            else:
                padding = original_length - stretched.shape[0]
                return F.pad(stretched, (0, padding))
        
        return waveform
    
    def pitch_shift(self, waveform: torch.Tensor, n_steps: int = None) -> torch.Tensor:
        """Pitch shifting"""
        if n_steps is None:
            n_steps = random.choice([-2, -1, 0, 1, 2])
        
        if n_steps == 0:
            return waveform
        
        # Approximate pitch shift via time stretch + resampling
        rate = 2 ** (n_steps / 12)
        shifted = self.time_stretch(waveform, rate)
        
        return shifted
    
    def add_noise(self, waveform: torch.Tensor, snr_db: float = None) -> torch.Tensor:
        """Add Gaussian noise"""
        if snr_db is None:
            snr_db = random.uniform(15, 30)  # SNR between 15-30 dB
        
        signal_power = waveform.pow(2).mean()
        snr_linear = 10 ** (snr_db / 10)
        noise_power = signal_power / snr_linear
        
        noise = torch.randn_like(waveform) * torch.sqrt(noise_power)
        
        return waveform + noise
    
    def __call__(self, waveform: torch.Tensor, prob: float = 0.5) -> torch.Tensor:
        """Apply random augmentation"""
        if random.random() < prob:
            aug_type = random.choice(['time_stretch', 'pitch_shift', 'noise'])
            
            if aug_type == 'time_stretch':
                return self.time_stretch(waveform)
            elif aug_type == 'pitch_shift':
                return self.pitch_shift(waveform)
            elif aug_type == 'noise':
                return self.add_noise(waveform)
        
        return waveform

# 2. DATASET

In [5]:
def build_audio_rttm_mapping(audio_dir: str, rttm_dir: str) -> List[Tuple[Path, Path]]:
    """
    Build mapping between audio files and RTTM files by index
    audio_0.wav <-> labels_0.rttm
    """
    audio_dir = Path(audio_dir)
    rttm_dir = Path(rttm_dir)
    
    # Get all audio and rttm files
    audio_files = sorted(audio_dir.glob('audio_*.wav'))
    rttm_files = sorted(rttm_dir.glob('labels_*.rttm'))
    
    # Also try other extensions
    if len(audio_files) == 0:
        for ext in ['.flac', '.sph', '.mp3']:
            audio_files = sorted(audio_dir.glob(f'audio_*{ext}'))
            if len(audio_files) > 0:
                break
    
    print(f"\nFound {len(audio_files)} audio files")
    print(f"Found {len(rttm_files)} RTTM files")
    
    if len(audio_files) == 0:
        raise ValueError(f"No audio files found in {audio_dir}")
    if len(rttm_files) == 0:
        raise ValueError(f"No RTTM files found in {rttm_dir}")
    
    # Extract indices and match
    audio_map = {}
    for audio_file in audio_files:
        # Extract index from "audio_123.wav"
        try:
            idx = int(audio_file.stem.split('_')[1])
            audio_map[idx] = audio_file
        except (IndexError, ValueError):
            print(f"Warning: Cannot parse index from {audio_file.name}")
    
    rttm_map = {}
    for rttm_file in rttm_files:
        # Extract index from "labels_123.rttm"
        try:
            idx = int(rttm_file.stem.split('_')[1])
            rttm_map[idx] = rttm_file
        except (IndexError, ValueError):
            print(f"Warning: Cannot parse index from {rttm_file.name}")
    
    # Match by index
    pairs = []
    for idx in sorted(audio_map.keys()):
        if idx in rttm_map:
            pairs.append((audio_map[idx], rttm_map[idx]))
        else:
            print(f"Warning: No RTTM file for audio index {idx}")
    
    print(f"Matched {len(pairs)} audio-RTTM pairs")
    
    if len(pairs) == 0:
        raise ValueError("No matching audio-RTTM pairs found")
    
    # Show examples
    print(f"\nExample pairs:")
    for audio_file, rttm_file in pairs[:3]:
        print(f"  {audio_file.name} <-> {rttm_file.name}")
    
    return pairs


class callhomeContrastiveDataset(Dataset):
    """
    Callhome dataset for contrastive learning
    Parses RTTM files to extract speaker segments
    Handles indexed naming: audio_0.wav <-> labels_0.rttm
    
    OPTIMIZATION: Caches audio files in memory for faster loading
    """
    def __init__(self, 
             audio_dir: str,
             rttm_dir: str,
             audio_processor: AudioProcessor,
             segment_length: float = 3.0,
             sample_rate: int = 16000,
             min_segment_length: float = 2.0,
             apply_augment: bool = True,
             cache_audio: bool = True,
             use_audio_augment: bool = True):  # NEW
    
        self.audio_dir = Path(audio_dir)
        self.rttm_dir = Path(rttm_dir)
        self.audio_processor = audio_processor
        self.segment_length = segment_length
        self.segment_samples = int(segment_length * sample_rate)
        self.min_samples = int(min_segment_length * sample_rate)
        self.sample_rate = sample_rate
        self.apply_augment = apply_augment
        self.cache_audio = cache_audio
        self.audio_cache = {}
        
        # Spec augmentation
        self.spec_augment = SpecAugment(
            freq_mask_param=15,
            time_mask_param=50,
            n_freq_masks=1,
            n_time_masks=1
        ) if apply_augment else None
        
        # Audio augmentation (NEW)
        self.use_audio_augment = use_audio_augment
        self.audio_augment = AudioAugmentation(sample_rate) if use_audio_augment else None
        
        # Build audio-RTTM pairs by index
        print("\n" + "="*60)
        print("Building audio-RTTM mapping...")
        print("="*60)
        self.audio_rttm_pairs = build_audio_rttm_mapping(audio_dir, rttm_dir)
        
        # Parse RTTM files to build speaker segments
        self.speaker_segments = self._parse_rttm_files()
        self.speakers = list(self.speaker_segments.keys())
        
        # Create flat list for indexing
        self.samples = []
        for spk_id, segments in self.speaker_segments.items():
            for seg in segments:
                if seg['duration'] >= min_segment_length:
                    self.samples.append((spk_id, seg))
        
        # Preload all audio files into memory if caching enabled
        if self.cache_audio:
            print("\nüì¶ Caching audio files in memory...")
            self._cache_all_audio()
        
        print(f"\n‚úì Loaded {len(self.speakers)} speakers, {len(self.samples)} segments")
        print(f"  Segment length: {segment_length}s, Min length: {min_segment_length}s")
        print(f"  SpecAugment: {'enabled' if apply_augment else 'disabled'}")
        print(f"  Audio augmentation: {'enabled' if use_audio_augment else 'disabled'}")
        print(f"  Audio caching: {'enabled' if cache_audio else 'disabled'}")
    
    def _parse_rttm_files(self) -> Dict[str, List[Dict]]:
        """
        Parse RTTM files to extract speaker segments
        Each audio file can have multiple speakers
        """
        speaker_segments = {}
        
        print(f"\nParsing {len(self.audio_rttm_pairs)} audio-RTTM pairs...")
        
        segments_found = 0
        
        for audio_path, rttm_path in tqdm(self.audio_rttm_pairs, desc="Parsing RTTM"):
            # Get file identifier for this pair
            file_idx = audio_path.stem.split('_')[1]  # e.g., "0" from "audio_0.wav"
            
            with open(rttm_path, 'r') as f:
                for line in f:
                    line = line.strip()
                    if not line or line.startswith('#'):
                        continue
                    
                    parts = line.split()
                    if len(parts) < 8 or parts[0] != 'SPEAKER':
                        continue
                    
                    # RTTM format: SPEAKER <file_id> 1 <start> <duration> <NA> <NA> <speaker_id> <NA>
                    file_id_in_rttm = parts[1]  # This might be different from our filename
                    start_time = float(parts[3])
                    duration = float(parts[4])
                    speaker_id = parts[7]
                    
                    # Build unique speaker ID: use file index + speaker ID
                    # This ensures speakers from different files are treated as different
                    full_speaker_id = f"file{file_idx}_spk{speaker_id}"
                    
                    # Create segment info
                    segment = {
                        'file_idx': file_idx,
                        'audio_path': audio_path,
                        'start': start_time,
                        'duration': duration,
                        'speaker_id': speaker_id
                    }
                    
                    if full_speaker_id not in speaker_segments:
                        speaker_segments[full_speaker_id] = []
                    
                    speaker_segments[full_speaker_id].append(segment)
                    segments_found += 1
        
        print(f"Total segments found: {segments_found}")
        
        # Filter speakers with at least 2 segments
        original_speaker_count = len(speaker_segments)
        speaker_segments = {
            spk: segs for spk, segs in speaker_segments.items() 
            if len(segs) >= 2
        }
        
        filtered_count = original_speaker_count - len(speaker_segments)
        print(f"Speakers after filtering (‚â•2 segments): {len(speaker_segments)}")
        if filtered_count > 0:
            print(f"  Removed {filtered_count} speakers with <2 segments")
        
        # Print statistics
        if speaker_segments:
            seg_counts = [len(segs) for segs in speaker_segments.values()]
            durations = [seg['duration'] for segs in speaker_segments.values() for seg in segs]
            print(f"\nSegment statistics:")
            print(f"  Per speaker - Min: {min(seg_counts)}, Max: {max(seg_counts)}, Avg: {np.mean(seg_counts):.1f}")
            print(f"  Duration - Min: {min(durations):.1f}s, Max: {max(durations):.1f}s, Avg: {np.mean(durations):.1f}s")
        
        return speaker_segments
    
    def _cache_all_audio(self):
        """Preload all audio files into memory"""
        unique_audio_files = set()
        for _, segment in self.samples:
            unique_audio_files.add(segment['audio_path'])
        
        print(f"Loading {len(unique_audio_files)} unique audio files...")
        for audio_path in tqdm(unique_audio_files, desc="Caching audio"):
            try:
                waveform, sr = torchaudio.load(str(audio_path))
                
                # Resample if needed
                if sr != self.sample_rate:
                    resampler = torchaudio.transforms.Resample(sr, self.sample_rate)
                    waveform = resampler(waveform)
                
                # Convert to mono
                if waveform.shape[0] > 1:
                    waveform = waveform.mean(dim=0, keepdim=True)
                
                self.audio_cache[str(audio_path)] = waveform.squeeze(0)
            except Exception as e:
                print(f"\nError caching {audio_path}: {e}")
        
        print(f"‚úì Cached {len(self.audio_cache)} audio files")
    
    def _load_audio_segment(self, segment_info: Dict) -> torch.Tensor:
        """Load audio segment from file or cache based on RTTM timing"""
        audio_path = segment_info['audio_path']
        start_time = segment_info['start']
        duration = segment_info['duration']
        
        try:
            # Load from cache or file
            if self.cache_audio and str(audio_path) in self.audio_cache:
                waveform = self.audio_cache[str(audio_path)]
            else:
                waveform, sr = torchaudio.load(str(audio_path))
                
                # Resample if needed
                if sr != self.sample_rate:
                    resampler = torchaudio.transforms.Resample(sr, self.sample_rate)
                    waveform = resampler(waveform)
                
                # Convert to mono
                if waveform.shape[0] > 1:
                    waveform = waveform.mean(dim=0, keepdim=True)
                
                waveform = waveform.squeeze(0)
            
            # Extract segment based on RTTM timing
            start_sample = int(start_time * self.sample_rate)
            duration_samples = int(duration * self.sample_rate)
            end_sample = start_sample + duration_samples
            
            # Clip to valid range
            start_sample = max(0, start_sample)
            end_sample = min(len(waveform), end_sample)
            
            segment = waveform[start_sample:end_sample]
            
            # Crop or pad to target length
            segment_len = len(segment)
            
            if segment_len < self.segment_samples:
                # Pad if too short
                padding = self.segment_samples - segment_len
                segment = F.pad(segment.unsqueeze(0), (0, padding)).squeeze(0)
            elif segment_len > self.segment_samples:
                # Random crop if too long
                max_start = segment_len - self.segment_samples
                crop_start = random.randint(0, max_start)
                segment = segment[crop_start:crop_start + self.segment_samples]
            
            return segment
        
        except Exception as e:
            print(f"\nError loading {audio_path}: {e}")
            # Return silence if loading fails
            return torch.zeros(self.segment_samples)
    
    def __len__(self) -> int:
        return len(self.samples)
    
    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, int]:
        """
        Returns:
            anchor: (n_mels, frames)
            positive: (n_mels, frames) - different segment from same speaker
            speaker_id: int
        """
        speaker_id, anchor_segment = self.samples[idx]
        
        # Load anchor
        anchor_wav = self._load_audio_segment(anchor_segment)
        
        # Apply audio augmentation to anchor (NEW)
        if self.use_audio_augment and self.audio_augment is not None:
            anchor_wav = self.audio_augment(anchor_wav, prob=0.5)
        
        anchor_mel = self.audio_processor(anchor_wav)
        
        # Apply spec augmentation
        if self.apply_augment and self.spec_augment is not None:
            anchor_mel = self.spec_augment(anchor_mel)
        
        # Load positive (different segment from same speaker)
        positive_segment = random.choice(self.speaker_segments[speaker_id])
        
        # Ensure it's different from anchor if possible
        if len(self.speaker_segments[speaker_id]) > 1:
            max_attempts = 10
            for _ in range(max_attempts):
                positive_segment = random.choice(self.speaker_segments[speaker_id])
                if positive_segment['start'] != anchor_segment['start']:
                    break
        
        positive_wav = self._load_audio_segment(positive_segment)
        
        # Apply DIFFERENT audio augmentation to positive (NEW)
        if self.use_audio_augment and self.audio_augment is not None:
            positive_wav = self.audio_augment(positive_wav, prob=0.5)
        
        positive_mel = self.audio_processor(positive_wav)
        
        # Apply different spec augmentation to positive
        if self.apply_augment and self.spec_augment is not None:
            positive_mel = self.spec_augment(positive_mel)
        
        # Convert speaker_id to numeric
        speaker_idx = self.speakers.index(speaker_id)
        
        return anchor_mel, positive_mel, speaker_idx

# 3. MODEL ARCHITECTURE

In [6]:
class ConformerBlock(nn.Module):
    """Single Conformer block with multi-head attention and convolution"""
    def __init__(self, d_model: int, n_heads: int, conv_kernel: int = 31, dropout: float = 0.1):
        super().__init__()
        
        # Feed-forward module 1
        self.ff1 = nn.Sequential(
            nn.LayerNorm(d_model),
            nn.Linear(d_model, d_model * 4),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Linear(d_model * 4, d_model),
            nn.Dropout(dropout)
        )
        
        # Multi-head self-attention
        self.norm_attn = nn.LayerNorm(d_model)
        self.attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout, batch_first=True)
        self.dropout_attn = nn.Dropout(dropout)
        
        # Convolution module
        self.norm_conv = nn.LayerNorm(d_model)
        self.conv = nn.Sequential(
            nn.Conv1d(d_model, d_model * 2, 1),
            nn.GLU(dim=1),
            nn.Conv1d(d_model, d_model, conv_kernel, padding=conv_kernel//2, groups=d_model),
            nn.BatchNorm1d(d_model),
            nn.SiLU(),
            nn.Conv1d(d_model, d_model, 1),
            nn.Dropout(dropout)
        )
        
        # Feed-forward module 2
        self.ff2 = nn.Sequential(
            nn.LayerNorm(d_model),
            nn.Linear(d_model, d_model * 4),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Linear(d_model * 4, d_model),
            nn.Dropout(dropout)
        )
        
        self.norm_out = nn.LayerNorm(d_model)
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (batch, time, d_model)
        Returns:
            output: (batch, time, d_model)
        """
        # FF1
        x = x + 0.5 * self.ff1(x)
        
        # Attention
        x_norm = self.norm_attn(x)
        attn_out, _ = self.attn(x_norm, x_norm, x_norm)
        x = x + self.dropout_attn(attn_out)
        
        # Convolution
        x_norm = self.norm_conv(x)
        x_conv = x_norm.transpose(1, 2)  # (batch, d_model, time)
        x_conv = self.conv(x_conv)
        x = x + x_conv.transpose(1, 2)
        
        # FF2
        x = x + 0.5 * self.ff2(x)
        
        return self.norm_out(x)


class ConformerEncoder(nn.Module):
    """Conformer encoder for audio processing"""
    def __init__(self, 
                 input_dim: int = 83,
                 d_model: int = 256,
                 n_layers: int = 8,
                 n_heads: int = 4,
                 conv_kernel: int = 31,
                 dropout: float = 0.1,
                 subsampling_factor: int = 4):
        super().__init__()
        
        # Subsampling layer (reduce frame rate by 4x)
        self.subsampling = nn.Sequential(
            nn.Conv1d(input_dim, d_model, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv1d(d_model, d_model, kernel_size=3, stride=2, padding=1),
            nn.ReLU()
        )
        
        # Positional encoding
        self.pos_encoding = PositionalEncoding(d_model, dropout)
        
        # Conformer blocks
        self.blocks = nn.ModuleList([
            ConformerBlock(d_model, n_heads, conv_kernel, dropout)
            for _ in range(n_layers)
        ])
        
        self.d_model = d_model
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (batch, n_mels, time)
        Returns:
            output: (batch, time//4, d_model)
        """
        # Subsampling
        x = self.subsampling(x)  # (batch, d_model, time//4)
        x = x.transpose(1, 2)  # (batch, time//4, d_model)
        
        # Positional encoding
        x = self.pos_encoding(x)
        
        # Conformer blocks
        for block in self.blocks:
            x = block(x)
        
        return x


class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding"""
    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 10000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(1, max_len, d_model)
        pe[0, :, 0::2] = torch.sin(position * div_term)
        pe[0, :, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.pe[:, :x.size(1)]
        return self.dropout(x)


class ProjectionHead(nn.Module):
    """Projection head for contrastive learning"""
    def __init__(self, input_dim: int, hidden_dim: int = 256, output_dim: int = 128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim)
        )
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


class ContraEENDEncoder(nn.Module):
    """Full encoder with projection head"""
    def __init__(self, 
                 input_dim: int = 83,
                 d_model: int = 256,
                 n_layers: int = 8,
                 n_heads: int = 4,
                 projection_dim: int = 128):
        super().__init__()
        
        self.encoder = ConformerEncoder(
            input_dim=input_dim,
            d_model=d_model,
            n_layers=n_layers,
            n_heads=n_heads
        )
        
        # Temporal pooling for utterance-level representation
        self.temporal_pool = nn.AdaptiveAvgPool1d(1)
        
        # Projection head
        self.projection = ProjectionHead(d_model, d_model, projection_dim)
    
    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            x: (batch, n_mels, time)
        Returns:
            embeddings: (batch, d_model) - frame-level pooled
            projections: (batch, projection_dim) - for contrastive loss
        """
        # Encode
        encoded = self.encoder(x)  # (batch, time//4, d_model)
        
        # Temporal pooling
        pooled = self.temporal_pool(encoded.transpose(1, 2)).squeeze(-1)  # (batch, d_model)
        
        # Project
        projected = self.projection(pooled)  # (batch, projection_dim)
        
        # L2 normalize projections
        projected = F.normalize(projected, p=2, dim=1)
        
        return pooled, projected

# 4. CONTRASTIVE LOSS


In [7]:
class SupConLoss(nn.Module):
    """Supervised Contrastive Loss with hard negative mining"""
    def __init__(self, temperature: float = 0.2, base_temperature: float = 0.2):
        super().__init__()
        self.temperature = temperature
        self.base_temperature = base_temperature
    
    def forward(self, features: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        device = features.device
        batch_size = features.shape[0]
        
        labels = labels.contiguous().view(-1, 1)
        mask = torch.eq(labels, labels.T).float().to(device)
        
        # Compute similarity matrix
        anchor_dot_contrast = torch.div(
            torch.matmul(features, features.T),
            self.temperature
        )
        
        # For numerical stability
        logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
        logits = anchor_dot_contrast - logits_max.detach()
        
        # Remove self-contrast
        logits_mask = torch.scatter(
            torch.ones_like(mask),
            1,
            torch.arange(batch_size).view(-1, 1).to(device),
            0
        )
        mask = mask * logits_mask
        
        # HARD NEGATIVE MINING: Weight harder negatives more
        exp_logits = torch.exp(logits) * logits_mask
        
        # Hard negative mask (only keep negatives, exclude self and positives)
        hard_negative_mask = (1 - mask) * logits_mask
        
        # Weighted by difficulty (higher similarity = harder negative)
        hard_weights = torch.exp(logits) * hard_negative_mask
        hard_weights = hard_weights / (hard_weights.sum(1, keepdim=True) + 1e-8)
        
        # Reweight denominator with hard negatives
        weighted_exp_logits = exp_logits * (1 + hard_weights)
        
        log_prob = logits - torch.log(weighted_exp_logits.sum(1, keepdim=True))
        
        # Compute mean over positives
        mean_log_prob_pos = (mask * log_prob).sum(1) / (mask.sum(1) + 1e-8)
        
        loss = - (self.temperature / self.base_temperature) * mean_log_prob_pos
        loss = loss.mean()
        
        return loss

# 5. Training

In [8]:
import csv
from datetime import datetime

class MetricsLogger:
    """CSV logger for tracking training metrics"""
    def __init__(self, log_dir: str = './logs', experiment_name: str = 'contraeend'):
        self.log_dir = Path(log_dir)
        self.log_dir.mkdir(parents=True, exist_ok=True)
        
        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        self.csv_path = self.log_dir / f'{experiment_name}_{timestamp}.csv'
        
        self.columns = [
            'timestamp', 'epoch', 'phase', 'learning_rate',
            'train_loss', 'val_loss',
            'pos_sim_mean', 'neg_sim_mean', 'separation',
            'accuracy', 'adjusted_rand_score', 'der', 'transition_rate',
            'avg_segment_duration', 'num_segments', 'temporal_stability',
            'per_speaker_accuracy_mean', 'per_speaker_accuracy_std',
            'num_speakers_detected', 'num_speakers_true',
            'has_temporal_modeling', 'oracle_speakers',
            'filename', 'duration',
            'batch_size', 'd_model', 'n_layers', 'temperature',
        ]
        
        with open(self.csv_path, 'w', newline='') as f:
            writer = csv.DictWriter(f, fieldnames=self.columns)
            writer.writeheader()
        
        print(f"üìä Metrics logger initialized: {self.csv_path}")
    
    def log_epoch(self, epoch, phase, train_loss, val_loss, learning_rate,
                  embedding_metrics=None, config=None):
        """Log metrics for one epoch"""
        row = {
            'timestamp': datetime.now().isoformat(),
            'epoch': epoch,
            'phase': phase,
            'learning_rate': learning_rate,
            'train_loss': train_loss,
            'val_loss': val_loss,
        }
        
        if embedding_metrics:
            row.update({
                'pos_sim_mean': embedding_metrics.get('pos_sim_mean', ''),
                'neg_sim_mean': embedding_metrics.get('neg_sim_mean', ''),
                'separation': embedding_metrics.get('separation', ''),
            })
        
        if config:
            row.update({
                'batch_size': config.get('batch_size', ''),
                'd_model': config.get('d_model', ''),
                'n_layers': config.get('n_layers', ''),
                'temperature': config.get('temperature', ''),
            })
        
        with open(self.csv_path, 'a', newline='') as f:
            writer = csv.DictWriter(f, fieldnames=self.columns)
            writer.writerow(row)

In [9]:
class ContrastiveTrainer:
    """Trainer for phase 2 v2: Contrastive Pretraining"""
    def __init__(self,
                 model: nn.Module,
                 train_loader: DataLoader,
                 val_loader: DataLoader,
                 num_epochs: int,
                 device: str = 'cuda',
                 learning_rate: float = 1e-3,
                 weight_decay: float = 1e-4,
                 temperature: float = 0.2,
                 accumulation_steps: int = 4, 
                 log_dir: str = './logs',
                 config: dict = None):
        
        self.model = model.to(device)
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.device = device
        
        # Loss
        self.criterion = SupConLoss(temperature=temperature)
        
        # Optimizer
        self.optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=learning_rate,
            weight_decay=weight_decay,
            betas=(0.9, 0.98)
        )
        
        # Warmup + Cosine scheduler
        self.warmup_epochs = 5
        self.total_steps = len(train_loader) * num_epochs
        self.warmup_steps = len(train_loader) * self.warmup_epochs
        self.current_step = 0
        
        def lr_lambda(current_step):
            if current_step < self.warmup_steps:
                # Linear warmup
                return float(current_step) / float(max(1, self.warmup_steps))
            # Cosine decay after warmup
            progress = float(current_step - self.warmup_steps) / float(max(1, self.total_steps - self.warmup_steps))
            return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))
        
        from torch.optim.lr_scheduler import LambdaLR
        self.scheduler = LambdaLR(self.optimizer, lr_lambda)
        
        # Mixed precision
        self.scaler = torch.cuda.amp.GradScaler() if device == 'cuda' else None
        self.use_amp = device == 'cuda'
        
        # Early stopping
        self.best_val_loss = float('inf')
        self.patience = 10
        self.patience_counter = 0
        self.min_delta = 1e-4
        self.accumulation_steps = accumulation_steps
        self.logger = MetricsLogger(log_dir=log_dir, experiment_name='contraeend')
        self.config = config
        print(f"Effective batch size: {train_loader.batch_size * 2 * accumulation_steps}")
    
    
    def train_epoch(self, epoch: int) -> float:
        """Train one epoch with gradient accumulation"""
        self.model.train()
        # Check GPU usage on first batch
        if epoch == 1:
            print("\nüîç Checking GPU usage on first batch...")
        total_loss = 0
        accumulated_loss = 0

        # Timing
        data_time = 0
        forward_time = 0
        backward_time = 0
        
        pbar = tqdm(self.train_loader, desc=f'Epoch {epoch} [Train]')
        for batch_idx, (anchor, positive, labels) in enumerate(pbar):
            t0 = time.time()
            
            anchor = anchor.to(self.device)
            positive = positive.to(self.device)
            labels = labels.to(self.device)
            combined = torch.cat([anchor, positive], dim=0)
            combined_labels = torch.cat([labels, labels], dim=0)
            
            data_time += time.time() - t0
            t1 = time.time()
            
            # Forward
            with torch.cuda.amp.autocast(enabled=self.use_amp):
                _, projections = self.model(combined)
                loss = self.criterion(projections, combined_labels)
                loss = loss / self.accumulation_steps
            
            forward_time += time.time() - t1
            t2 = time.time()
                
            # Backward
            if self.use_amp:
                self.scaler.scale(loss).backward()
            else:
                loss.backward()
            
            accumulated_loss += loss.item()
            backward_time += time.time() - t2
            # Show timing on first epoch
            if epoch == 1 and batch_idx == 10:
                print(f"\n‚è±Ô∏è Timing breakdown (first 10 batches):")
                print(f"  Data loading: {data_time/10:.2f}s per batch")
                print(f"  Forward pass: {forward_time/10:.2f}s per batch")
                print(f"  Backward pass: {backward_time/10:.2f}s per batch")
            
            # Update weights every accumulation_steps
            if (batch_idx + 1) % self.accumulation_steps == 0:
                if self.use_amp:
                    self.scaler.unscale_(self.optimizer)
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=5.0)
                    self.scaler.step(self.optimizer)
                    self.scaler.update()
                else:
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=5.0)
                    self.optimizer.step()
                
                self.optimizer.zero_grad()
                
                # Step scheduler
                self.current_step += 1
                self.scheduler.step()
                
                total_loss += accumulated_loss
                current_lr = self.optimizer.param_groups[0]['lr']
                pbar.set_postfix({
                    'loss': f'{accumulated_loss:.4f}',
                    'lr': f'{current_lr:.6f}'
                })
                accumulated_loss = 0
        
        avg_loss = total_loss / (len(self.train_loader) // self.accumulation_steps)
        return avg_loss
    
    def validate(self, epoch: int) -> float:
        """Validate"""
        self.model.eval()
        total_loss = 0
        
        with torch.no_grad():
            pbar = tqdm(self.val_loader, desc=f'Epoch {epoch} [Val]')
            for anchor, positive, labels in pbar:
                anchor = anchor.to(self.device)
                positive = positive.to(self.device)
                labels = labels.to(self.device)
                
                combined = torch.cat([anchor, positive], dim=0)
                combined_labels = torch.cat([labels, labels], dim=0)
                
                with torch.cuda.amp.autocast(enabled=self.use_amp):
                    _, projections = self.model(combined)
                    loss = self.criterion(projections, combined_labels)
                
                total_loss += loss.item()
                pbar.set_postfix({'loss': f'{loss.item():.4f}'})
        
        avg_loss = total_loss / len(self.val_loader)
        return avg_loss
    
    def compute_embedding_metrics(self, loader: DataLoader) -> Dict[str, float]:
        """Compute cosine similarity metrics for validation"""
        self.model.eval()
        embeddings_list = []
        labels_list = []
        
        with torch.no_grad():
            for batch_idx, (anchor, positive, labels) in enumerate(loader):
                anchor = anchor.to(self.device)
                with torch.cuda.amp.autocast(enabled=self.use_amp):
                    _, anchor_proj = self.model(anchor)
                embeddings_list.append(anchor_proj.cpu())
                labels_list.append(labels)
                
                if batch_idx >= 20:  # Limit samples for speed
                    break
        
        if len(embeddings_list) == 0:
            return {'pos_sim_mean': 0.0, 'neg_sim_mean': 0.0, 'separation': 0.0}
        
        embeddings = torch.cat(embeddings_list, dim=0)  # (N, projection_dim)
        labels = torch.cat(labels_list, dim=0)  # (N,)
        
        # Compute cosine similarity matrix
        sim_matrix = torch.matmul(embeddings, embeddings.T)  # Already normalized
        
        # Same speaker pairs (positive)
        same_speaker_mask = (labels.unsqueeze(0) == labels.unsqueeze(1))
        same_speaker_mask.fill_diagonal_(False)  # Exclude self
        same_speaker_sims = sim_matrix[same_speaker_mask]
        
        # Different speaker pairs (negative)
        diff_speaker_mask = ~same_speaker_mask
        diff_speaker_mask.fill_diagonal_(False)
        diff_speaker_sims = sim_matrix[diff_speaker_mask]
        
        metrics = {
            'pos_sim_mean': same_speaker_sims.mean().item() if len(same_speaker_sims) > 0 else 0.0,
            'neg_sim_mean': diff_speaker_sims.mean().item() if len(diff_speaker_sims) > 0 else 0.0,
            'separation': (same_speaker_sims.mean() - diff_speaker_sims.mean()).item() if len(same_speaker_sims) > 0 else 0.0
        }
        
        return metrics
    
    def save_checkpoint(self, epoch: int, loss: float, filepath: str):
        """Save training checkpoint"""
        torch.save({
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
            'loss': loss,
            'best_val_loss': self.best_val_loss,
            'current_step': self.current_step,
            'patience_counter': self.patience_counter
        }, filepath)
        print(f"Checkpoint saved: {filepath}")
    
    def load_checkpoint(self, filepath: str) -> int:
        """Resume from checkpoint"""
        if not os.path.exists(filepath):
            print(f"No checkpoint found at {filepath}")
            return 0
        
        checkpoint = torch.load(filepath, map_location=self.device)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        self.best_val_loss = checkpoint.get('best_val_loss', float('inf'))
        self.current_step = checkpoint.get('current_step', 0)
        
        # RESET patience for enhancement training
        self.patience_counter = 0  # Force reset
        
        epoch = checkpoint['epoch']
        print(f"‚úì Resumed from epoch {epoch}, best val loss: {self.best_val_loss:.4f}")
        print(f"‚ö†Ô∏è  Patience counter reset to 0 for enhancement training")
        return epoch
        
    def train(self, num_epochs: int, checkpoint_dir: str = './checkpoints', start_epoch: int = 0):
        """Full training loop with early stopping"""
        Path(checkpoint_dir).mkdir(parents=True, exist_ok=True)
        
        for epoch in range(start_epoch + 1, num_epochs + 1):
            print(f"\n{'='*60}")
            print(f"Epoch {epoch}/{num_epochs}")
            print(f"{'='*60}")
            
            # Train
            train_loss = self.train_epoch(epoch)
            print(f"Train Loss: {train_loss:.4f}")
            
            # Validate
            val_loss = self.validate(epoch)
            print(f"Val Loss: {val_loss:.4f}")
            
            # Compute embedding metrics
            metrics = self.compute_embedding_metrics(self.val_loader)
            print(f"Pos Sim: {metrics['pos_sim_mean']:.4f} | Neg Sim: {metrics['neg_sim_mean']:.4f} | Separation: {metrics['separation']:.4f}")
            
            current_lr = self.optimizer.param_groups[0]['lr']
            print(f"Learning Rate: {current_lr:.6f}")

            self.logger.log_epoch(
                epoch=epoch, phase='pretrain', train_loss=train_loss,
                val_loss=val_loss, learning_rate=current_lr,
                embedding_metrics=metrics, config=self.config
            )
            
            # Save checkpoint every epoch
            checkpoint_path = Path(checkpoint_dir) / f'contraeend_epoch_{epoch}.pth'
            self.save_checkpoint(epoch, val_loss, str(checkpoint_path))
            
            # Early stopping check
            if val_loss < self.best_val_loss - self.min_delta:
                self.best_val_loss = val_loss
                self.patience_counter = 0
                best_path = Path(checkpoint_dir) / 'contraeend_best.pth'
                self.save_checkpoint(epoch, val_loss, str(best_path))
                print(f"‚úì New best model saved! Val Loss: {val_loss:.4f}")
            else:
                self.patience_counter += 1
                print(f"Patience: {self.patience_counter}/{self.patience}")
                
                if self.patience_counter >= self.patience:
                    print(f"\n‚ö†Ô∏è Early stopping triggered after {epoch} epochs")
                    break
        
        print(f"\nTraining completed! Best validation loss: {self.best_val_loss:.4f}")


# 6. Training Script

In [10]:
def main():
    """Main training function for phase 2 v2"""
    
    # Set random seed for reproducibility
    set_seed(42)

    # Configuration
    config = {
        'audio_dir': '/kaggle/input/callhome/audio',
        'rttm_dir': '/kaggle/input/callhome/labels',
        'batch_size': 128,  # INCREASED from 64
        'num_epochs': 200,   # Reduced since we have better augmentation
        'learning_rate': 5e-5,
        'weight_decay': 1e-4,
        'temperature': 0.15,  # LOWERED from 0.2 (more strict)
        'segment_length': 3.0,
        'd_model': 128,
        'n_layers': 6,
        'n_heads': 4,
        'projection_dim': 64,
        'num_workers': 4,
        'device': 'cuda' if torch.cuda.is_available() else 'cpu',
        'resume_from': '/kaggle/input/contraeend/pytorch/default/1/contraeend_best.pth',  # Resume from best model
        'cache_audio': True,
        'use_audio_augment': True,  # NEW: Enable audio augmentation
    }
    
    print("="*60)
    print("ContraEEND - phase 2 v2: Contrastive Pretraining (FIXED)")
    print("="*60)
    print(f"Device: {config['device']}")
    print(f"Audio dir: {config['audio_dir']}")
    print(f"RTTM dir: {config['rttm_dir']}")
    print(f"Batch Size: {config['batch_size']}")
    print(f"Temperature: {config['temperature']}")
    print(f"Model: Conformer ({config['n_layers']} layers, {config['d_model']} dim)")
    print("="*60)
    
    # Audio processor
    audio_processor = AudioProcessor()
    
    # Dataset with RTTM parsing
    print("\nLoading datasets...")
    full_dataset = callhomeContrastiveDataset(
        audio_dir=config['audio_dir'],
        rttm_dir=config['rttm_dir'],
        audio_processor=audio_processor,
        segment_length=config['segment_length'],
        apply_augment=True,
        cache_audio=config['cache_audio']  # NEW: Pass cache parameter
    )
        
    # Check if dataset is empty
    if len(full_dataset) == 0:
        print("\n‚ùå ERROR: No valid samples found!")
        print("\nPossible issues:")
        print("1. Audio files don't match RTTM file IDs")
        print("2. All segments are too short (< 2.0s)")
        print("3. No speakers have ‚â•2 segments")
        return
    
    # Dataset diagnostics
    print("\n" + "="*60)
    print("üîç DATASET DIAGNOSTICS")
    print("="*60)
    print(f"Total samples: {len(full_dataset)}")
    print(f"Total speakers: {len(full_dataset.speakers)}")
    
    if len(full_dataset.speakers) > 0:
        speaker_counts = {spk: len(segs) for spk, segs in full_dataset.speaker_segments.items()}
        segments_per_speaker = list(speaker_counts.values())
        
        print(f"Avg segments per speaker: {np.mean(segments_per_speaker):.1f}")
        print(f"Min segments per speaker: {min(segments_per_speaker)}")
        print(f"Max segments per speaker: {max(segments_per_speaker)}")
        
        print(f"\nSample speakers: {full_dataset.speakers[:5]}")
    
    # Check if dataset is sufficient
    if len(full_dataset) < 200:
        print("\n‚ö†Ô∏è  WARNING: Dataset is small!")
        print(f"   Current: {len(full_dataset)} samples")
        print(f"   This may limit model performance")
    
    if len(full_dataset.speakers) < 20:
        print("\n‚ö†Ô∏è  WARNING: Few speakers!")
        print(f"   Current: {len(full_dataset.speakers)} speakers")
        print(f"   Contrastive learning works best with 100+ speakers")
    
    print("="*60 + "\n")
    
    # Split for validation (90/10)
    train_size = int(0.9 * len(full_dataset))
    val_size = len(full_dataset) - train_size
    train_dataset, val_dataset = torch.utils.data.random_split(
        full_dataset, [train_size, val_size],
        generator=torch.Generator().manual_seed(42)
    )
    
    print(f"Train samples: {len(train_dataset)}")
    print(f"Val samples: {len(val_dataset)}")
    
    # Adjust batch size and accumulation if needed
    if len(train_dataset) < config['batch_size'] * 4:
        print("\n‚ö†Ô∏è  Small dataset: adjusting batch size and disabling accumulation")
        config['batch_size'] = min(16, len(train_dataset) // 2)
        accumulation_steps = 1
    else:
        accumulation_steps = 4
    
    print(f"Batch size: {config['batch_size']}")
    print(f"Accumulation steps: {accumulation_steps}")
    print(f"Effective batch size: {config['batch_size'] * accumulation_steps * 2}")
    
    # Data loaders
    train_loader = DataLoader(
        train_dataset,
        batch_size=config['batch_size'],
        shuffle=True,
        num_workers=config['num_workers'],
        pin_memory=True if config['device'] == 'cuda' else False,
        drop_last=True if len(train_dataset) > config['batch_size'] else False,
        persistent_workers=True if config['num_workers'] > 0 else False  # NEW
    )
    
    val_loader = DataLoader(
        val_dataset,
        batch_size=config['batch_size'],
        shuffle=False,
        num_workers=config['num_workers'],
        pin_memory=True if config['device'] == 'cuda' else False,
        persistent_workers=True if config['num_workers'] > 0 else False  # NEW
    )
        
    print(f"Train batches: {len(train_loader)}")
    print(f"Val batches: {len(val_loader)}")
    
    # Model
    print("\nInitializing model...")
    model = ContraEENDEncoder(
        input_dim=83,
        d_model=config['d_model'],
        n_layers=config['n_layers'],
        n_heads=config['n_heads'],
        projection_dim=config['projection_dim']
    )
    
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Total parameters: {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,}")
    
    # Trainer
    print("\nInitializing trainer...")
    trainer = ContrastiveTrainer(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        num_epochs=config['num_epochs'],
        device=config['device'],
        learning_rate=config['learning_rate'],
        weight_decay=config['weight_decay'],
        temperature=config['temperature'],
        accumulation_steps=8 # INCREASED from 4: Effective batch = 128*8*2 = 2048
    )
    
    # Resume from checkpoint if specified
    start_epoch = 0
    if config['resume_from'] is not None and os.path.exists(config['resume_from']):
        print(f"\nResuming from checkpoint: {config['resume_from']}")
        start_epoch = trainer.load_checkpoint(config['resume_from'])
        # FORCE RESET PATIENCE
        trainer.patience_counter = 0
        trainer.patience = 15  # more patience
        print(f"‚ö†Ô∏è  Patience reset to 0/15 for enhancement training")
    
    
    # Train
    print("\nStarting training...")
    print("Features enabled:")
    print("="*60)
    
    trainer.train(
        num_epochs=config['num_epochs'],
        start_epoch=start_epoch
    )
    
    print("\n" + "="*60)
    print("== Training Complete! ==")
    print(f"Best Validation Loss: {trainer.best_val_loss:.4f}")
    print("="*60)


if __name__ == '__main__':
    main()

ContraEEND - Phase 1: Contrastive Pretraining (FIXED)
Device: cuda
Audio dir: /kaggle/input/callhome/audio
RTTM dir: /kaggle/input/callhome/labels
Batch Size: 128
Temperature: 0.15
Model: Conformer (6 layers, 128 dim)

Loading datasets...

Building audio-RTTM mapping...

Found 140 audio files
Found 140 RTTM files
Matched 140 audio-RTTM pairs

Example pairs:
  audio_0.wav <-> labels_0.rttm
  audio_1.wav <-> labels_1.rttm
  audio_2.wav <-> labels_2.rttm

Parsing 140 audio-RTTM pairs...


Parsing RTTM: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 140/140 [00:00<00:00, 172.24it/s]


Total segments found: 33471
Speakers after filtering (‚â•2 segments): 289
  Removed 5 speakers with <2 segments

Segment statistics:
  Per speaker - Min: 2, Max: 226, Avg: 115.8
  Duration - Min: 0.0s, Max: 21.1s, Avg: 2.1s

üì¶ Caching audio files in memory...
Loading 140 unique audio files...


Caching audio: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 140/140 [00:41<00:00,  3.34it/s]


‚úì Cached 140 audio files

‚úì Loaded 289 speakers, 12876 segments
  Segment length: 3.0s, Min length: 2.0s
  SpecAugment: enabled
  Audio augmentation: enabled
  Audio caching: enabled

üîç DATASET DIAGNOSTICS
Total samples: 12876
Total speakers: 289
Avg segments per speaker: 115.8
Min segments per speaker: 2
Max segments per speaker: 226

Sample speakers: ['file0_spkA', 'file0_spkB', 'file1_spkA', 'file1_spkB', 'file2_spkB']

Train samples: 11588
Val samples: 1288
Batch size: 128
Accumulation steps: 4
Effective batch size: 1024
Train batches: 90
Val batches: 11

Initializing model...
Total parameters: 2,413,888
Trainable parameters: 2,413,888

Initializing trainer...


  self.scaler = torch.cuda.amp.GradScaler() if device == 'cuda' else None


üìä Metrics logger initialized: logs/contraeend_20251024_033005.csv
Effective batch size: 2048

Resuming from checkpoint: /kaggle/input/contraeend/pytorch/default/1/contraeend_best.pth
‚úì Resumed from epoch 83, best val loss: 1.8993
‚ö†Ô∏è  Patience counter reset to 0 for enhancement training
‚ö†Ô∏è  Patience reset to 0/15 for enhancement training

Starting training...
Features enabled:

Epoch 84/200


  with torch.cuda.amp.autocast(enabled=self.use_amp):
Epoch 84 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 90/90 [00:20<00:00,  4.36it/s, loss=2.2362, lr=0.000458]


Train Loss: 2.4208


  with torch.cuda.amp.autocast(enabled=self.use_amp):
Epoch 84 [Val]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 11/11 [00:02<00:00,  4.32it/s, loss=0.4872]


Val Loss: 2.0381


  with torch.cuda.amp.autocast(enabled=self.use_amp):


Pos Sim: 0.7280 | Neg Sim: 0.0090 | Separation: 0.7190
Learning Rate: 0.000458
Checkpoint saved: checkpoints/contraeend_epoch_84.pth
Patience: 1/15

Epoch 85/200


Epoch 85 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 90/90 [00:19<00:00,  4.66it/s, loss=2.0872, lr=0.000457]


Train Loss: 2.1253


Epoch 85 [Val]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 11/11 [00:02<00:00,  5.36it/s, loss=0.4705]


Val Loss: 1.9136
Pos Sim: 0.7612 | Neg Sim: 0.0057 | Separation: 0.7555
Learning Rate: 0.000457
Checkpoint saved: checkpoints/contraeend_epoch_85.pth
Patience: 2/15

Epoch 86/200


Epoch 86 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 90/90 [00:18<00:00,  4.77it/s, loss=1.9109, lr=0.000457]


Train Loss: 2.0034


Epoch 86 [Val]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 11/11 [00:02<00:00,  5.26it/s, loss=0.4677]


Val Loss: 1.9148
Pos Sim: 0.7660 | Neg Sim: 0.0033 | Separation: 0.7628
Learning Rate: 0.000457
Checkpoint saved: checkpoints/contraeend_epoch_86.pth
Patience: 3/15

Epoch 87/200


Epoch 87 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 90/90 [00:19<00:00,  4.68it/s, loss=1.8949, lr=0.000457]


Train Loss: 1.9580


Epoch 87 [Val]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 11/11 [00:02<00:00,  5.40it/s, loss=0.3797]


Val Loss: 1.8521
Pos Sim: 0.7529 | Neg Sim: 0.0041 | Separation: 0.7488
Learning Rate: 0.000457
Checkpoint saved: checkpoints/contraeend_epoch_87.pth
Checkpoint saved: checkpoints/contraeend_best.pth
‚úì New best model saved! Val Loss: 1.8521

Epoch 88/200


Epoch 88 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 90/90 [00:18<00:00,  4.75it/s, loss=1.8997, lr=0.000457]


Train Loss: 1.9174


Epoch 88 [Val]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 11/11 [00:02<00:00,  5.39it/s, loss=0.4107]


Val Loss: 1.7992
Pos Sim: 0.7847 | Neg Sim: 0.0052 | Separation: 0.7794
Learning Rate: 0.000457
Checkpoint saved: checkpoints/contraeend_epoch_88.pth
Checkpoint saved: checkpoints/contraeend_best.pth
‚úì New best model saved! Val Loss: 1.7992

Epoch 89/200


Epoch 89 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 90/90 [00:18<00:00,  4.82it/s, loss=1.8797, lr=0.000456]


Train Loss: 1.8968


Epoch 89 [Val]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 11/11 [00:02<00:00,  5.36it/s, loss=0.8494]


Val Loss: 1.8457
Pos Sim: 0.7936 | Neg Sim: 0.0027 | Separation: 0.7909
Learning Rate: 0.000456
Checkpoint saved: checkpoints/contraeend_epoch_89.pth
Patience: 1/15

Epoch 90/200


Epoch 90 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 90/90 [00:18<00:00,  4.81it/s, loss=1.8758, lr=0.000456]


Train Loss: 1.8691


Epoch 90 [Val]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 11/11 [00:02<00:00,  5.03it/s, loss=0.5443]


Val Loss: 1.8051
Pos Sim: 0.7762 | Neg Sim: 0.0032 | Separation: 0.7730
Learning Rate: 0.000456
Checkpoint saved: checkpoints/contraeend_epoch_90.pth
Patience: 2/15

Epoch 91/200


Epoch 91 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 90/90 [00:18<00:00,  4.86it/s, loss=1.7773, lr=0.000456]


Train Loss: 1.8322


Epoch 91 [Val]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 11/11 [00:01<00:00,  5.67it/s, loss=0.2604]


Val Loss: 1.7381
Pos Sim: 0.7868 | Neg Sim: 0.0020 | Separation: 0.7849
Learning Rate: 0.000456
Checkpoint saved: checkpoints/contraeend_epoch_91.pth
Checkpoint saved: checkpoints/contraeend_best.pth
‚úì New best model saved! Val Loss: 1.7381

Epoch 92/200


Epoch 92 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 90/90 [00:18<00:00,  4.80it/s, loss=1.8168, lr=0.000456]


Train Loss: 1.8273


Epoch 92 [Val]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 11/11 [00:01<00:00,  5.52it/s, loss=0.7587]


Val Loss: 1.8012
Pos Sim: 0.7826 | Neg Sim: 0.0025 | Separation: 0.7801
Learning Rate: 0.000456
Checkpoint saved: checkpoints/contraeend_epoch_92.pth
Patience: 1/15

Epoch 93/200


Epoch 93 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 90/90 [00:18<00:00,  4.77it/s, loss=1.8122, lr=0.000455]


Train Loss: 1.8188


Epoch 93 [Val]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 11/11 [00:02<00:00,  5.49it/s, loss=0.3023]


Val Loss: 1.7097
Pos Sim: 0.7919 | Neg Sim: 0.0042 | Separation: 0.7877
Learning Rate: 0.000455
Checkpoint saved: checkpoints/contraeend_epoch_93.pth
Checkpoint saved: checkpoints/contraeend_best.pth
‚úì New best model saved! Val Loss: 1.7097

Epoch 94/200


Epoch 94 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 90/90 [00:19<00:00,  4.73it/s, loss=1.7847, lr=0.000455]


Train Loss: 1.8189


Epoch 94 [Val]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 11/11 [00:02<00:00,  5.36it/s, loss=0.3932]


Val Loss: 1.7601
Pos Sim: 0.7806 | Neg Sim: 0.0043 | Separation: 0.7762
Learning Rate: 0.000455
Checkpoint saved: checkpoints/contraeend_epoch_94.pth
Patience: 1/15

Epoch 95/200


Epoch 95 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 90/90 [00:18<00:00,  4.79it/s, loss=1.7645, lr=0.000455]


Train Loss: 1.8074


Epoch 95 [Val]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 11/11 [00:02<00:00,  5.40it/s, loss=0.4560]


Val Loss: 1.7562
Pos Sim: 0.7963 | Neg Sim: 0.0033 | Separation: 0.7930
Learning Rate: 0.000455
Checkpoint saved: checkpoints/contraeend_epoch_95.pth
Patience: 2/15

Epoch 96/200


Epoch 96 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 90/90 [00:18<00:00,  4.81it/s, loss=1.7602, lr=0.000454]


Train Loss: 1.7911


Epoch 96 [Val]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 11/11 [00:02<00:00,  5.22it/s, loss=0.7154]


Val Loss: 1.6901
Pos Sim: 0.8052 | Neg Sim: 0.0022 | Separation: 0.8031
Learning Rate: 0.000454
Checkpoint saved: checkpoints/contraeend_epoch_96.pth
Checkpoint saved: checkpoints/contraeend_best.pth
‚úì New best model saved! Val Loss: 1.6901

Epoch 97/200


Epoch 97 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 90/90 [00:18<00:00,  4.75it/s, loss=1.7997, lr=0.000454]


Train Loss: 1.7759


Epoch 97 [Val]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 11/11 [00:02<00:00,  4.83it/s, loss=0.4194]


Val Loss: 1.7035
Pos Sim: 0.7984 | Neg Sim: 0.0024 | Separation: 0.7960
Learning Rate: 0.000454
Checkpoint saved: checkpoints/contraeend_epoch_97.pth
Patience: 1/15

Epoch 98/200


Epoch 98 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 90/90 [00:19<00:00,  4.69it/s, loss=1.7586, lr=0.000454]


Train Loss: 1.7806


Epoch 98 [Val]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 11/11 [00:02<00:00,  5.40it/s, loss=0.7430]


Val Loss: 1.6965
Pos Sim: 0.8027 | Neg Sim: 0.0008 | Separation: 0.8019
Learning Rate: 0.000454
Checkpoint saved: checkpoints/contraeend_epoch_98.pth
Patience: 2/15

Epoch 99/200


Epoch 99 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 90/90 [00:18<00:00,  4.80it/s, loss=1.7315, lr=0.000454]


Train Loss: 1.7507


Epoch 99 [Val]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 11/11 [00:02<00:00,  5.35it/s, loss=0.4772]


Val Loss: 1.6951
Pos Sim: 0.7960 | Neg Sim: 0.0035 | Separation: 0.7925
Learning Rate: 0.000454
Checkpoint saved: checkpoints/contraeend_epoch_99.pth
Patience: 3/15

Epoch 100/200


Epoch 100 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 90/90 [00:19<00:00,  4.60it/s, loss=1.7572, lr=0.000453]


Train Loss: 1.7594


Epoch 100 [Val]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 11/11 [00:02<00:00,  5.18it/s, loss=1.0919]


Val Loss: 1.7616
Pos Sim: 0.7744 | Neg Sim: 0.0034 | Separation: 0.7710
Learning Rate: 0.000453
Checkpoint saved: checkpoints/contraeend_epoch_100.pth
Patience: 4/15

Epoch 101/200


Epoch 101 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 90/90 [00:19<00:00,  4.65it/s, loss=1.7357, lr=0.000453]


Train Loss: 1.7621


Epoch 101 [Val]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 11/11 [00:02<00:00,  5.24it/s, loss=0.1688]


Val Loss: 1.7012
Pos Sim: 0.7967 | Neg Sim: 0.0046 | Separation: 0.7921
Learning Rate: 0.000453
Checkpoint saved: checkpoints/contraeend_epoch_101.pth
Patience: 5/15

Epoch 102/200


Epoch 102 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 90/90 [00:19<00:00,  4.50it/s, loss=1.6995, lr=0.000453]


Train Loss: 1.7562


Epoch 102 [Val]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 11/11 [00:02<00:00,  5.04it/s, loss=0.4925]


Val Loss: 1.6458
Pos Sim: 0.8169 | Neg Sim: 0.0002 | Separation: 0.8166
Learning Rate: 0.000453
Checkpoint saved: checkpoints/contraeend_epoch_102.pth
Checkpoint saved: checkpoints/contraeend_best.pth
‚úì New best model saved! Val Loss: 1.6458

Epoch 103/200


Epoch 103 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 90/90 [00:20<00:00,  4.46it/s, loss=1.6979, lr=0.000452]


Train Loss: 1.7223


Epoch 103 [Val]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 11/11 [00:02<00:00,  5.09it/s, loss=0.7831]


Val Loss: 1.6801
Pos Sim: 0.7972 | Neg Sim: 0.0019 | Separation: 0.7953
Learning Rate: 0.000452
Checkpoint saved: checkpoints/contraeend_epoch_103.pth
Patience: 1/15

Epoch 104/200


Epoch 104 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 90/90 [00:20<00:00,  4.48it/s, loss=1.7335, lr=0.000452]


Train Loss: 1.7235


Epoch 104 [Val]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 11/11 [00:02<00:00,  4.99it/s, loss=0.4983]


Val Loss: 1.6435
Pos Sim: 0.8063 | Neg Sim: 0.0015 | Separation: 0.8048
Learning Rate: 0.000452
Checkpoint saved: checkpoints/contraeend_epoch_104.pth
Checkpoint saved: checkpoints/contraeend_best.pth
‚úì New best model saved! Val Loss: 1.6435

Epoch 105/200


Epoch 105 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 90/90 [00:20<00:00,  4.46it/s, loss=1.7182, lr=0.000452]


Train Loss: 1.7292


Epoch 105 [Val]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 11/11 [00:02<00:00,  4.99it/s, loss=0.3849]


Val Loss: 1.6339
Pos Sim: 0.8110 | Neg Sim: 0.0007 | Separation: 0.8102
Learning Rate: 0.000452
Checkpoint saved: checkpoints/contraeend_epoch_105.pth
Checkpoint saved: checkpoints/contraeend_best.pth
‚úì New best model saved! Val Loss: 1.6339

Epoch 106/200


Epoch 106 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 90/90 [00:19<00:00,  4.50it/s, loss=1.7170, lr=0.000452]


Train Loss: 1.7015


Epoch 106 [Val]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 11/11 [00:02<00:00,  5.09it/s, loss=0.1753]


Val Loss: 1.6531
Pos Sim: 0.7883 | Neg Sim: 0.0008 | Separation: 0.7875
Learning Rate: 0.000452
Checkpoint saved: checkpoints/contraeend_epoch_106.pth
Patience: 1/15

Epoch 107/200


Epoch 107 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 90/90 [00:20<00:00,  4.47it/s, loss=1.6393, lr=0.000451]


Train Loss: 1.6987


Epoch 107 [Val]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 11/11 [00:02<00:00,  5.07it/s, loss=0.3907]


Val Loss: 1.6563
Pos Sim: 0.8127 | Neg Sim: 0.0036 | Separation: 0.8091
Learning Rate: 0.000451
Checkpoint saved: checkpoints/contraeend_epoch_107.pth
Patience: 2/15

Epoch 108/200


Epoch 108 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 90/90 [00:20<00:00,  4.43it/s, loss=1.6439, lr=0.000451]


Train Loss: 1.6974


Epoch 108 [Val]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 11/11 [00:02<00:00,  5.14it/s, loss=0.6497]


Val Loss: 1.6434
Pos Sim: 0.8186 | Neg Sim: 0.0018 | Separation: 0.8167
Learning Rate: 0.000451
Checkpoint saved: checkpoints/contraeend_epoch_108.pth
Patience: 3/15

Epoch 109/200


Epoch 109 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 90/90 [00:20<00:00,  4.46it/s, loss=1.7062, lr=0.000451]


Train Loss: 1.6929


Epoch 109 [Val]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 11/11 [00:02<00:00,  4.99it/s, loss=0.5929]


Val Loss: 1.6524
Pos Sim: 0.7869 | Neg Sim: 0.0015 | Separation: 0.7855
Learning Rate: 0.000451
Checkpoint saved: checkpoints/contraeend_epoch_109.pth
Patience: 4/15

Epoch 110/200


Epoch 110 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 90/90 [00:20<00:00,  4.41it/s, loss=1.6424, lr=0.000450]


Train Loss: 1.6746


Epoch 110 [Val]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 11/11 [00:02<00:00,  4.94it/s, loss=0.3562]


Val Loss: 1.6480
Pos Sim: 0.8055 | Neg Sim: 0.0011 | Separation: 0.8044
Learning Rate: 0.000450
Checkpoint saved: checkpoints/contraeend_epoch_110.pth
Patience: 5/15

Epoch 111/200


Epoch 111 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 90/90 [00:20<00:00,  4.37it/s, loss=1.7165, lr=0.000450]


Train Loss: 1.6881


Epoch 111 [Val]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 11/11 [00:02<00:00,  4.89it/s, loss=0.7499]


Val Loss: 1.6159
Pos Sim: 0.8109 | Neg Sim: 0.0019 | Separation: 0.8090
Learning Rate: 0.000450
Checkpoint saved: checkpoints/contraeend_epoch_111.pth
Checkpoint saved: checkpoints/contraeend_best.pth
‚úì New best model saved! Val Loss: 1.6159

Epoch 112/200


Epoch 112 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 90/90 [00:20<00:00,  4.33it/s, loss=1.6399, lr=0.000450]


Train Loss: 1.6725


Epoch 112 [Val]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 11/11 [00:02<00:00,  5.03it/s, loss=0.3469]


Val Loss: 1.5774
Pos Sim: 0.8210 | Neg Sim: 0.0001 | Separation: 0.8208
Learning Rate: 0.000450
Checkpoint saved: checkpoints/contraeend_epoch_112.pth
Checkpoint saved: checkpoints/contraeend_best.pth
‚úì New best model saved! Val Loss: 1.5774

Epoch 113/200


Epoch 113 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 90/90 [00:19<00:00,  4.50it/s, loss=1.7104, lr=0.000449]


Train Loss: 1.6642


Epoch 113 [Val]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 11/11 [00:02<00:00,  4.87it/s, loss=0.2197]


Val Loss: 1.5374
Pos Sim: 0.7973 | Neg Sim: 0.0003 | Separation: 0.7970
Learning Rate: 0.000449
Checkpoint saved: checkpoints/contraeend_epoch_113.pth
Checkpoint saved: checkpoints/contraeend_best.pth
‚úì New best model saved! Val Loss: 1.5374

Epoch 114/200


Epoch 114 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 90/90 [00:20<00:00,  4.37it/s, loss=1.7116, lr=0.000449]


Train Loss: 1.6485


Epoch 114 [Val]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 11/11 [00:02<00:00,  5.02it/s, loss=0.2343]


Val Loss: 1.5858
Pos Sim: 0.8058 | Neg Sim: 0.0035 | Separation: 0.8023
Learning Rate: 0.000449
Checkpoint saved: checkpoints/contraeend_epoch_114.pth
Patience: 1/15

Epoch 115/200


Epoch 115 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 90/90 [00:20<00:00,  4.42it/s, loss=1.6686, lr=0.000449]


Train Loss: 1.6541


Epoch 115 [Val]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 11/11 [00:02<00:00,  4.86it/s, loss=0.5895]


Val Loss: 1.6121
Pos Sim: 0.8103 | Neg Sim: 0.0011 | Separation: 0.8093
Learning Rate: 0.000449
Checkpoint saved: checkpoints/contraeend_epoch_115.pth
Patience: 2/15

Epoch 116/200


Epoch 116 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 90/90 [00:19<00:00,  4.57it/s, loss=1.6313, lr=0.000449]


Train Loss: 1.6376


Epoch 116 [Val]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 11/11 [00:02<00:00,  5.17it/s, loss=0.3218]


Val Loss: 1.5620
Pos Sim: 0.8224 | Neg Sim: 0.0002 | Separation: 0.8222
Learning Rate: 0.000449
Checkpoint saved: checkpoints/contraeend_epoch_116.pth
Patience: 3/15

Epoch 117/200


Epoch 117 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 90/90 [00:19<00:00,  4.54it/s, loss=1.6225, lr=0.000448]


Train Loss: 1.6378


Epoch 117 [Val]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 11/11 [00:02<00:00,  5.05it/s, loss=0.3343]


Val Loss: 1.5290
Pos Sim: 0.8249 | Neg Sim: 0.0007 | Separation: 0.8242
Learning Rate: 0.000448
Checkpoint saved: checkpoints/contraeend_epoch_117.pth
Checkpoint saved: checkpoints/contraeend_best.pth
‚úì New best model saved! Val Loss: 1.5290

Epoch 118/200


Epoch 118 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 90/90 [00:20<00:00,  4.47it/s, loss=1.5966, lr=0.000448]


Train Loss: 1.6202


Epoch 118 [Val]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 11/11 [00:02<00:00,  5.09it/s, loss=0.2917]


Val Loss: 1.5549
Pos Sim: 0.8294 | Neg Sim: 0.0009 | Separation: 0.8285
Learning Rate: 0.000448
Checkpoint saved: checkpoints/contraeend_epoch_118.pth
Patience: 1/15

Epoch 119/200


Epoch 119 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 90/90 [00:19<00:00,  4.54it/s, loss=1.6314, lr=0.000448]


Train Loss: 1.6390


Epoch 119 [Val]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 11/11 [00:02<00:00,  5.17it/s, loss=0.4577]


Val Loss: 1.5541
Pos Sim: 0.8088 | Neg Sim: 0.0001 | Separation: 0.8086
Learning Rate: 0.000448
Checkpoint saved: checkpoints/contraeend_epoch_119.pth
Patience: 2/15

Epoch 120/200


Epoch 120 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 90/90 [00:19<00:00,  4.57it/s, loss=1.6786, lr=0.000447]


Train Loss: 1.6053


Epoch 120 [Val]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 11/11 [00:02<00:00,  5.25it/s, loss=0.3096]


Val Loss: 1.5636
Pos Sim: 0.8134 | Neg Sim: 0.0016 | Separation: 0.8118
Learning Rate: 0.000447
Checkpoint saved: checkpoints/contraeend_epoch_120.pth
Patience: 3/15

Epoch 121/200


Epoch 121 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 90/90 [00:20<00:00,  4.47it/s, loss=1.5797, lr=0.000447]


Train Loss: 1.6082


Epoch 121 [Val]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 11/11 [00:02<00:00,  5.03it/s, loss=0.3384]


Val Loss: 1.5561
Pos Sim: 0.8066 | Neg Sim: 0.0020 | Separation: 0.8046
Learning Rate: 0.000447
Checkpoint saved: checkpoints/contraeend_epoch_121.pth
Patience: 4/15

Epoch 122/200


Epoch 122 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 90/90 [00:19<00:00,  4.51it/s, loss=1.6158, lr=0.000447]


Train Loss: 1.6050


Epoch 122 [Val]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 11/11 [00:02<00:00,  4.94it/s, loss=0.2877]


Val Loss: 1.5197
Pos Sim: 0.8233 | Neg Sim: -0.0003 | Separation: 0.8236
Learning Rate: 0.000447
Checkpoint saved: checkpoints/contraeend_epoch_122.pth
Checkpoint saved: checkpoints/contraeend_best.pth
‚úì New best model saved! Val Loss: 1.5197

Epoch 123/200


Epoch 123 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 90/90 [00:19<00:00,  4.54it/s, loss=1.5530, lr=0.000446]


Train Loss: 1.5927


Epoch 123 [Val]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 11/11 [00:02<00:00,  5.08it/s, loss=0.5895]


Val Loss: 1.5457
Pos Sim: 0.8148 | Neg Sim: 0.0011 | Separation: 0.8137
Learning Rate: 0.000446
Checkpoint saved: checkpoints/contraeend_epoch_123.pth
Patience: 1/15

Epoch 124/200


Epoch 124 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 90/90 [00:19<00:00,  4.55it/s, loss=1.5548, lr=0.000446]


Train Loss: 1.5899


Epoch 124 [Val]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 11/11 [00:02<00:00,  4.88it/s, loss=0.2197]


Val Loss: 1.5067
Pos Sim: 0.8174 | Neg Sim: 0.0011 | Separation: 0.8163
Learning Rate: 0.000446
Checkpoint saved: checkpoints/contraeend_epoch_124.pth
Checkpoint saved: checkpoints/contraeend_best.pth
‚úì New best model saved! Val Loss: 1.5067

Epoch 125/200


Epoch 125 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 90/90 [00:19<00:00,  4.61it/s, loss=1.5706, lr=0.000446]


Train Loss: 1.5862


Epoch 125 [Val]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 11/11 [00:02<00:00,  5.27it/s, loss=0.3940]


Val Loss: 1.5360
Pos Sim: 0.8196 | Neg Sim: 0.0028 | Separation: 0.8168
Learning Rate: 0.000446
Checkpoint saved: checkpoints/contraeend_epoch_125.pth
Patience: 1/15

Epoch 126/200


Epoch 126 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 90/90 [00:19<00:00,  4.61it/s, loss=1.6131, lr=0.000446]


Train Loss: 1.5768


Epoch 126 [Val]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 11/11 [00:02<00:00,  5.31it/s, loss=0.3238]


Val Loss: 1.5426
Pos Sim: 0.8236 | Neg Sim: 0.0006 | Separation: 0.8230
Learning Rate: 0.000446
Checkpoint saved: checkpoints/contraeend_epoch_126.pth
Patience: 2/15

Epoch 127/200


Epoch 127 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 90/90 [00:19<00:00,  4.58it/s, loss=1.5179, lr=0.000445]


Train Loss: 1.5622


Epoch 127 [Val]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 11/11 [00:02<00:00,  5.21it/s, loss=0.4712]


Val Loss: 1.5642
Pos Sim: 0.8130 | Neg Sim: 0.0004 | Separation: 0.8126
Learning Rate: 0.000445
Checkpoint saved: checkpoints/contraeend_epoch_127.pth
Patience: 3/15

Epoch 128/200


Epoch 128 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 90/90 [00:20<00:00,  4.40it/s, loss=1.5417, lr=0.000445]


Train Loss: 1.5567


Epoch 128 [Val]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 11/11 [00:02<00:00,  5.07it/s, loss=0.2920]


Val Loss: 1.5408
Pos Sim: 0.8134 | Neg Sim: 0.0004 | Separation: 0.8130
Learning Rate: 0.000445
Checkpoint saved: checkpoints/contraeend_epoch_128.pth
Patience: 4/15

Epoch 129/200


Epoch 129 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 90/90 [00:19<00:00,  4.50it/s, loss=1.5337, lr=0.000445]


Train Loss: 1.5625


Epoch 129 [Val]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 11/11 [00:02<00:00,  4.97it/s, loss=0.7451]


Val Loss: 1.5202
Pos Sim: 0.8338 | Neg Sim: -0.0006 | Separation: 0.8344
Learning Rate: 0.000445
Checkpoint saved: checkpoints/contraeend_epoch_129.pth
Patience: 5/15

Epoch 130/200


Epoch 130 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 90/90 [00:19<00:00,  4.58it/s, loss=1.5837, lr=0.000444]


Train Loss: 1.5702


Epoch 130 [Val]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 11/11 [00:02<00:00,  5.15it/s, loss=0.3870]


Val Loss: 1.5631
Pos Sim: 0.8141 | Neg Sim: 0.0031 | Separation: 0.8110
Learning Rate: 0.000444
Checkpoint saved: checkpoints/contraeend_epoch_130.pth
Patience: 6/15

Epoch 131/200


Epoch 131 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 90/90 [00:20<00:00,  4.50it/s, loss=1.6131, lr=0.000444]


Train Loss: 1.5737


Epoch 131 [Val]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 11/11 [00:02<00:00,  5.12it/s, loss=0.4057]


Val Loss: 1.5423
Pos Sim: 0.8063 | Neg Sim: 0.0001 | Separation: 0.8062
Learning Rate: 0.000444
Checkpoint saved: checkpoints/contraeend_epoch_131.pth
Patience: 7/15

Epoch 132/200


Epoch 132 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 90/90 [00:19<00:00,  4.54it/s, loss=1.5360, lr=0.000444]


Train Loss: 1.5474


Epoch 132 [Val]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 11/11 [00:02<00:00,  5.06it/s, loss=0.1852]


Val Loss: 1.4284
Pos Sim: 0.8250 | Neg Sim: -0.0009 | Separation: 0.8259
Learning Rate: 0.000444
Checkpoint saved: checkpoints/contraeend_epoch_132.pth
Checkpoint saved: checkpoints/contraeend_best.pth
‚úì New best model saved! Val Loss: 1.4284

Epoch 133/200


Epoch 133 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 90/90 [00:19<00:00,  4.55it/s, loss=1.5617, lr=0.000443]


Train Loss: 1.5632


Epoch 133 [Val]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 11/11 [00:02<00:00,  4.95it/s, loss=0.2913]


Val Loss: 1.5212
Pos Sim: 0.8187 | Neg Sim: 0.0026 | Separation: 0.8161
Learning Rate: 0.000443
Checkpoint saved: checkpoints/contraeend_epoch_133.pth
Patience: 1/15

Epoch 134/200


Epoch 134 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 90/90 [00:19<00:00,  4.64it/s, loss=1.6118, lr=0.000443]


Train Loss: 1.5445


Epoch 134 [Val]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 11/11 [00:02<00:00,  5.11it/s, loss=0.4324]


Val Loss: 1.4747
Pos Sim: 0.8233 | Neg Sim: 0.0014 | Separation: 0.8219
Learning Rate: 0.000443
Checkpoint saved: checkpoints/contraeend_epoch_134.pth
Patience: 2/15

Epoch 135/200


Epoch 135 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 90/90 [00:19<00:00,  4.54it/s, loss=1.5647, lr=0.000443]


Train Loss: 1.5553


Epoch 135 [Val]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 11/11 [00:02<00:00,  5.05it/s, loss=0.3477]


Val Loss: 1.4844
Pos Sim: 0.8155 | Neg Sim: 0.0009 | Separation: 0.8147
Learning Rate: 0.000443
Checkpoint saved: checkpoints/contraeend_epoch_135.pth
Patience: 3/15

Epoch 136/200


Epoch 136 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 90/90 [00:19<00:00,  4.52it/s, loss=1.4347, lr=0.000442]


Train Loss: 1.5330


Epoch 136 [Val]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 11/11 [00:02<00:00,  5.15it/s, loss=0.2395]


Val Loss: 1.5001
Pos Sim: 0.8178 | Neg Sim: 0.0026 | Separation: 0.8152
Learning Rate: 0.000442
Checkpoint saved: checkpoints/contraeend_epoch_136.pth
Patience: 4/15

Epoch 137/200


Epoch 137 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 90/90 [00:20<00:00,  4.49it/s, loss=1.5084, lr=0.000442]


Train Loss: 1.5195


Epoch 137 [Val]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 11/11 [00:02<00:00,  5.02it/s, loss=0.5712]


Val Loss: 1.5663
Pos Sim: 0.8149 | Neg Sim: 0.0029 | Separation: 0.8120
Learning Rate: 0.000442
Checkpoint saved: checkpoints/contraeend_epoch_137.pth
Patience: 5/15

Epoch 138/200


Epoch 138 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 90/90 [00:19<00:00,  4.56it/s, loss=1.5346, lr=0.000442]


Train Loss: 1.5244


Epoch 138 [Val]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 11/11 [00:02<00:00,  5.09it/s, loss=0.5105]


Val Loss: 1.4822
Pos Sim: 0.8357 | Neg Sim: 0.0024 | Separation: 0.8333
Learning Rate: 0.000442
Checkpoint saved: checkpoints/contraeend_epoch_138.pth
Patience: 6/15

Epoch 139/200


Epoch 139 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 90/90 [00:19<00:00,  4.54it/s, loss=1.5142, lr=0.000441]


Train Loss: 1.4995


Epoch 139 [Val]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 11/11 [00:02<00:00,  5.16it/s, loss=0.6228]


Val Loss: 1.5074
Pos Sim: 0.8272 | Neg Sim: -0.0004 | Separation: 0.8276
Learning Rate: 0.000441
Checkpoint saved: checkpoints/contraeend_epoch_139.pth
Patience: 7/15

Epoch 140/200


Epoch 140 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 90/90 [00:19<00:00,  4.59it/s, loss=1.5457, lr=0.000441]


Train Loss: 1.5369


Epoch 140 [Val]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 11/11 [00:02<00:00,  5.22it/s, loss=0.2679]


Val Loss: 1.4899
Pos Sim: 0.8148 | Neg Sim: 0.0026 | Separation: 0.8122
Learning Rate: 0.000441
Checkpoint saved: checkpoints/contraeend_epoch_140.pth
Patience: 8/15

Epoch 141/200


Epoch 141 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 90/90 [00:19<00:00,  4.54it/s, loss=1.5370, lr=0.000441]


Train Loss: 1.5240


Epoch 141 [Val]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 11/11 [00:02<00:00,  4.98it/s, loss=0.1960]


Val Loss: 1.4727
Pos Sim: 0.8309 | Neg Sim: -0.0007 | Separation: 0.8316
Learning Rate: 0.000441
Checkpoint saved: checkpoints/contraeend_epoch_141.pth
Patience: 9/15

Epoch 142/200


Epoch 142 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 90/90 [00:19<00:00,  4.62it/s, loss=1.5342, lr=0.000441]


Train Loss: 1.5012


Epoch 142 [Val]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 11/11 [00:02<00:00,  5.14it/s, loss=0.3149]


Val Loss: 1.4568
Pos Sim: 0.8228 | Neg Sim: -0.0004 | Separation: 0.8232
Learning Rate: 0.000441
Checkpoint saved: checkpoints/contraeend_epoch_142.pth
Patience: 10/15

Epoch 143/200


Epoch 143 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 90/90 [00:19<00:00,  4.52it/s, loss=1.5297, lr=0.000440]


Train Loss: 1.5146


Epoch 143 [Val]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 11/11 [00:02<00:00,  4.96it/s, loss=0.5472]


Val Loss: 1.5412
Pos Sim: 0.8245 | Neg Sim: 0.0033 | Separation: 0.8212
Learning Rate: 0.000440
Checkpoint saved: checkpoints/contraeend_epoch_143.pth
Patience: 11/15

Epoch 144/200


Epoch 144 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 90/90 [00:19<00:00,  4.59it/s, loss=1.4415, lr=0.000440]


Train Loss: 1.5122


Epoch 144 [Val]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 11/11 [00:02<00:00,  5.16it/s, loss=0.6552]


Val Loss: 1.4764
Pos Sim: 0.8181 | Neg Sim: 0.0000 | Separation: 0.8181
Learning Rate: 0.000440
Checkpoint saved: checkpoints/contraeend_epoch_144.pth
Patience: 12/15

Epoch 145/200


Epoch 145 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 90/90 [00:19<00:00,  4.52it/s, loss=1.5058, lr=0.000440]


Train Loss: 1.5016


Epoch 145 [Val]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 11/11 [00:02<00:00,  5.16it/s, loss=0.1275]


Val Loss: 1.3965
Pos Sim: 0.8556 | Neg Sim: -0.0012 | Separation: 0.8568
Learning Rate: 0.000440
Checkpoint saved: checkpoints/contraeend_epoch_145.pth
Checkpoint saved: checkpoints/contraeend_best.pth
‚úì New best model saved! Val Loss: 1.3965

Epoch 146/200


Epoch 146 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 90/90 [00:19<00:00,  4.53it/s, loss=1.4497, lr=0.000439]


Train Loss: 1.4794


Epoch 146 [Val]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 11/11 [00:02<00:00,  5.02it/s, loss=0.1972]


Val Loss: 1.4612
Pos Sim: 0.8294 | Neg Sim: 0.0014 | Separation: 0.8281
Learning Rate: 0.000439
Checkpoint saved: checkpoints/contraeend_epoch_146.pth
Patience: 1/15

Epoch 147/200


Epoch 147 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 90/90 [00:19<00:00,  4.53it/s, loss=1.4953, lr=0.000439]


Train Loss: 1.4847


Epoch 147 [Val]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 11/11 [00:02<00:00,  5.05it/s, loss=0.6808]


Val Loss: 1.4880
Pos Sim: 0.8446 | Neg Sim: 0.0017 | Separation: 0.8429
Learning Rate: 0.000439
Checkpoint saved: checkpoints/contraeend_epoch_147.pth
Patience: 2/15

Epoch 148/200


Epoch 148 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 90/90 [00:20<00:00,  4.49it/s, loss=1.4533, lr=0.000439]


Train Loss: 1.4833


Epoch 148 [Val]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 11/11 [00:02<00:00,  4.81it/s, loss=0.3807]


Val Loss: 1.4296
Pos Sim: 0.8357 | Neg Sim: 0.0001 | Separation: 0.8356
Learning Rate: 0.000439
Checkpoint saved: checkpoints/contraeend_epoch_148.pth
Patience: 3/15

Epoch 149/200


Epoch 149 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 90/90 [00:19<00:00,  4.56it/s, loss=1.5022, lr=0.000438]


Train Loss: 1.4716


Epoch 149 [Val]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 11/11 [00:02<00:00,  5.18it/s, loss=0.6569]


Val Loss: 1.4898
Pos Sim: 0.8407 | Neg Sim: -0.0004 | Separation: 0.8410
Learning Rate: 0.000438
Checkpoint saved: checkpoints/contraeend_epoch_149.pth
Patience: 4/15

Epoch 150/200


Epoch 150 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 90/90 [00:19<00:00,  4.56it/s, loss=1.5257, lr=0.000438]


Train Loss: 1.4678


Epoch 150 [Val]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 11/11 [00:02<00:00,  5.37it/s, loss=0.1860]


Val Loss: 1.4437
Pos Sim: 0.8449 | Neg Sim: 0.0003 | Separation: 0.8446
Learning Rate: 0.000438
Checkpoint saved: checkpoints/contraeend_epoch_150.pth
Patience: 5/15

Epoch 151/200


Epoch 151 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 90/90 [00:20<00:00,  4.48it/s, loss=1.5341, lr=0.000438]


Train Loss: 1.4627


Epoch 151 [Val]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 11/11 [00:02<00:00,  5.06it/s, loss=0.4383]


Val Loss: 1.3995
Pos Sim: 0.8478 | Neg Sim: -0.0011 | Separation: 0.8489
Learning Rate: 0.000438
Checkpoint saved: checkpoints/contraeend_epoch_151.pth
Patience: 6/15

Epoch 152/200


Epoch 152 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 90/90 [00:19<00:00,  4.54it/s, loss=1.4901, lr=0.000437]


Train Loss: 1.4599


Epoch 152 [Val]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 11/11 [00:02<00:00,  5.03it/s, loss=0.2780]


Val Loss: 1.4742
Pos Sim: 0.8216 | Neg Sim: 0.0025 | Separation: 0.8191
Learning Rate: 0.000437
Checkpoint saved: checkpoints/contraeend_epoch_152.pth
Patience: 7/15

Epoch 153/200


Epoch 153 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 90/90 [00:19<00:00,  4.50it/s, loss=1.4629, lr=0.000437]


Train Loss: 1.4527


Epoch 153 [Val]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 11/11 [00:02<00:00,  4.95it/s, loss=0.2254]


Val Loss: 1.4293
Pos Sim: 0.8211 | Neg Sim: 0.0017 | Separation: 0.8194
Learning Rate: 0.000437
Checkpoint saved: checkpoints/contraeend_epoch_153.pth
Patience: 8/15

Epoch 154/200


Epoch 154 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 90/90 [00:20<00:00,  4.49it/s, loss=1.4562, lr=0.000437]


Train Loss: 1.4527


Epoch 154 [Val]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 11/11 [00:02<00:00,  5.14it/s, loss=0.5637]


Val Loss: 1.4649
Pos Sim: 0.8196 | Neg Sim: 0.0021 | Separation: 0.8175
Learning Rate: 0.000437
Checkpoint saved: checkpoints/contraeend_epoch_154.pth
Patience: 9/15

Epoch 155/200


Epoch 155 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 90/90 [00:19<00:00,  4.50it/s, loss=1.4544, lr=0.000436]


Train Loss: 1.4668


Epoch 155 [Val]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 11/11 [00:02<00:00,  4.86it/s, loss=0.6974]


Val Loss: 1.4342
Pos Sim: 0.8278 | Neg Sim: -0.0008 | Separation: 0.8286
Learning Rate: 0.000436
Checkpoint saved: checkpoints/contraeend_epoch_155.pth
Patience: 10/15

Epoch 156/200


Epoch 156 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 90/90 [00:19<00:00,  4.55it/s, loss=1.4523, lr=0.000436]


Train Loss: 1.4563


Epoch 156 [Val]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 11/11 [00:02<00:00,  5.07it/s, loss=0.5806]


Val Loss: 1.4391
Pos Sim: 0.8432 | Neg Sim: -0.0007 | Separation: 0.8438
Learning Rate: 0.000436
Checkpoint saved: checkpoints/contraeend_epoch_156.pth
Patience: 11/15

Epoch 157/200


Epoch 157 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 90/90 [00:19<00:00,  4.55it/s, loss=1.4310, lr=0.000436]


Train Loss: 1.4375


Epoch 157 [Val]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 11/11 [00:02<00:00,  5.24it/s, loss=0.1511]


Val Loss: 1.4427
Pos Sim: 0.8268 | Neg Sim: 0.0009 | Separation: 0.8260
Learning Rate: 0.000436
Checkpoint saved: checkpoints/contraeend_epoch_157.pth
Patience: 12/15

Epoch 158/200


Epoch 158 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 90/90 [00:19<00:00,  4.56it/s, loss=1.4403, lr=0.000435]


Train Loss: 1.4529


Epoch 158 [Val]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 11/11 [00:02<00:00,  5.07it/s, loss=0.2231]


Val Loss: 1.3920
Pos Sim: 0.8391 | Neg Sim: -0.0000 | Separation: 0.8391
Learning Rate: 0.000435
Checkpoint saved: checkpoints/contraeend_epoch_158.pth
Checkpoint saved: checkpoints/contraeend_best.pth
‚úì New best model saved! Val Loss: 1.3920

Epoch 159/200


Epoch 159 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 90/90 [00:19<00:00,  4.51it/s, loss=1.4743, lr=0.000435]


Train Loss: 1.4570


Epoch 159 [Val]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 11/11 [00:02<00:00,  5.12it/s, loss=0.3858]


Val Loss: 1.4015
Pos Sim: 0.8446 | Neg Sim: 0.0011 | Separation: 0.8435
Learning Rate: 0.000435
Checkpoint saved: checkpoints/contraeend_epoch_159.pth
Patience: 1/15

Epoch 160/200


Epoch 160 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 90/90 [00:20<00:00,  4.46it/s, loss=1.4443, lr=0.000435]


Train Loss: 1.4451


Epoch 160 [Val]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 11/11 [00:02<00:00,  5.01it/s, loss=0.1990]


Val Loss: 1.3951
Pos Sim: 0.8398 | Neg Sim: -0.0009 | Separation: 0.8406
Learning Rate: 0.000435
Checkpoint saved: checkpoints/contraeend_epoch_160.pth
Patience: 2/15

Epoch 161/200


Epoch 161 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 90/90 [00:20<00:00,  4.45it/s, loss=1.4107, lr=0.000434]


Train Loss: 1.4154


Epoch 161 [Val]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 11/11 [00:02<00:00,  5.24it/s, loss=0.1151]


Val Loss: 1.3822
Pos Sim: 0.8446 | Neg Sim: -0.0003 | Separation: 0.8449
Learning Rate: 0.000434
Checkpoint saved: checkpoints/contraeend_epoch_161.pth
Checkpoint saved: checkpoints/contraeend_best.pth
‚úì New best model saved! Val Loss: 1.3822

Epoch 162/200


Epoch 162 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 90/90 [00:20<00:00,  4.30it/s, loss=1.3856, lr=0.000434]


Train Loss: 1.4183


Epoch 162 [Val]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 11/11 [00:02<00:00,  4.39it/s, loss=0.2410]


Val Loss: 1.4217
Pos Sim: 0.8264 | Neg Sim: -0.0007 | Separation: 0.8272
Learning Rate: 0.000434
Checkpoint saved: checkpoints/contraeend_epoch_162.pth
Patience: 1/15

Epoch 163/200


Epoch 163 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 90/90 [00:20<00:00,  4.29it/s, loss=1.4271, lr=0.000434]


Train Loss: 1.4414


Epoch 163 [Val]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 11/11 [00:02<00:00,  4.74it/s, loss=0.4136]


Val Loss: 1.3749
Pos Sim: 0.8491 | Neg Sim: -0.0008 | Separation: 0.8499
Learning Rate: 0.000434
Checkpoint saved: checkpoints/contraeend_epoch_163.pth
Checkpoint saved: checkpoints/contraeend_best.pth
‚úì New best model saved! Val Loss: 1.3749

Epoch 164/200


Epoch 164 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 90/90 [00:20<00:00,  4.43it/s, loss=1.4123, lr=0.000433]


Train Loss: 1.4160


Epoch 164 [Val]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 11/11 [00:02<00:00,  4.90it/s, loss=0.3314]


Val Loss: 1.3854
Pos Sim: 0.8387 | Neg Sim: 0.0006 | Separation: 0.8380
Learning Rate: 0.000433
Checkpoint saved: checkpoints/contraeend_epoch_164.pth
Patience: 1/15

Epoch 165/200


Epoch 165 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 90/90 [00:20<00:00,  4.44it/s, loss=1.4526, lr=0.000433]


Train Loss: 1.4176


Epoch 165 [Val]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 11/11 [00:02<00:00,  4.86it/s, loss=0.1493]


Val Loss: 1.3732
Pos Sim: 0.8338 | Neg Sim: 0.0005 | Separation: 0.8333
Learning Rate: 0.000433
Checkpoint saved: checkpoints/contraeend_epoch_165.pth
Checkpoint saved: checkpoints/contraeend_best.pth
‚úì New best model saved! Val Loss: 1.3732

Epoch 166/200


Epoch 166 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 90/90 [00:20<00:00,  4.39it/s, loss=1.4882, lr=0.000433]


Train Loss: 1.4242


Epoch 166 [Val]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 11/11 [00:02<00:00,  5.04it/s, loss=0.1057]


Val Loss: 1.4351
Pos Sim: 0.8349 | Neg Sim: 0.0001 | Separation: 0.8348
Learning Rate: 0.000433
Checkpoint saved: checkpoints/contraeend_epoch_166.pth
Patience: 1/15

Epoch 167/200


Epoch 167 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 90/90 [00:20<00:00,  4.47it/s, loss=1.3954, lr=0.000432]


Train Loss: 1.4152


Epoch 167 [Val]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 11/11 [00:02<00:00,  5.08it/s, loss=0.1620]


Val Loss: 1.4025
Pos Sim: 0.8369 | Neg Sim: 0.0006 | Separation: 0.8363
Learning Rate: 0.000432
Checkpoint saved: checkpoints/contraeend_epoch_167.pth
Patience: 2/15

Epoch 168/200


Epoch 168 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 90/90 [00:19<00:00,  4.53it/s, loss=1.4231, lr=0.000432]


Train Loss: 1.4104


Epoch 168 [Val]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 11/11 [00:02<00:00,  5.04it/s, loss=0.2198]


Val Loss: 1.3463
Pos Sim: 0.8579 | Neg Sim: -0.0005 | Separation: 0.8583
Learning Rate: 0.000432
Checkpoint saved: checkpoints/contraeend_epoch_168.pth
Checkpoint saved: checkpoints/contraeend_best.pth
‚úì New best model saved! Val Loss: 1.3463

Epoch 169/200


Epoch 169 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 90/90 [00:20<00:00,  4.43it/s, loss=1.4282, lr=0.000432]


Train Loss: 1.4059


Epoch 169 [Val]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 11/11 [00:02<00:00,  4.95it/s, loss=0.2672]


Val Loss: 1.3599
Pos Sim: 0.8318 | Neg Sim: -0.0005 | Separation: 0.8324
Learning Rate: 0.000432
Checkpoint saved: checkpoints/contraeend_epoch_169.pth
Patience: 1/15

Epoch 170/200


Epoch 170 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 90/90 [00:20<00:00,  4.41it/s, loss=1.3776, lr=0.000431]


Train Loss: 1.3993


Epoch 170 [Val]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 11/11 [00:02<00:00,  5.05it/s, loss=0.2427]


Val Loss: 1.3953
Pos Sim: 0.8492 | Neg Sim: 0.0009 | Separation: 0.8483
Learning Rate: 0.000431
Checkpoint saved: checkpoints/contraeend_epoch_170.pth
Patience: 2/15

Epoch 171/200


Epoch 171 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 90/90 [00:20<00:00,  4.42it/s, loss=1.4774, lr=0.000431]


Train Loss: 1.4097


Epoch 171 [Val]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 11/11 [00:02<00:00,  4.94it/s, loss=0.0767]


Val Loss: 1.3964
Pos Sim: 0.8287 | Neg Sim: 0.0036 | Separation: 0.8251
Learning Rate: 0.000431
Checkpoint saved: checkpoints/contraeend_epoch_171.pth
Patience: 3/15

Epoch 172/200


Epoch 172 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 90/90 [00:19<00:00,  4.55it/s, loss=1.3398, lr=0.000431]


Train Loss: 1.3932


Epoch 172 [Val]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 11/11 [00:02<00:00,  5.08it/s, loss=0.1060]


Val Loss: 1.3926
Pos Sim: 0.8424 | Neg Sim: 0.0002 | Separation: 0.8422
Learning Rate: 0.000431
Checkpoint saved: checkpoints/contraeend_epoch_172.pth
Patience: 4/15

Epoch 173/200


Epoch 173 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 90/90 [00:19<00:00,  4.56it/s, loss=1.4113, lr=0.000430]


Train Loss: 1.3954


Epoch 173 [Val]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 11/11 [00:02<00:00,  5.06it/s, loss=0.1288]


Val Loss: 1.3933
Pos Sim: 0.8503 | Neg Sim: -0.0011 | Separation: 0.8514
Learning Rate: 0.000430
Checkpoint saved: checkpoints/contraeend_epoch_173.pth
Patience: 5/15

Epoch 174/200


Epoch 174 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 90/90 [00:20<00:00,  4.48it/s, loss=1.3727, lr=0.000430]


Train Loss: 1.3833


Epoch 174 [Val]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 11/11 [00:02<00:00,  5.21it/s, loss=0.2080]


Val Loss: 1.3471
Pos Sim: 0.8347 | Neg Sim: -0.0003 | Separation: 0.8350
Learning Rate: 0.000430
Checkpoint saved: checkpoints/contraeend_epoch_174.pth
Patience: 6/15

Epoch 175/200


Epoch 175 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 90/90 [00:19<00:00,  4.51it/s, loss=1.4065, lr=0.000430]


Train Loss: 1.3862


Epoch 175 [Val]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 11/11 [00:02<00:00,  5.13it/s, loss=0.4552]


Val Loss: 1.3591
Pos Sim: 0.8533 | Neg Sim: -0.0001 | Separation: 0.8534
Learning Rate: 0.000430
Checkpoint saved: checkpoints/contraeend_epoch_175.pth
Patience: 7/15

Epoch 176/200


Epoch 176 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 90/90 [00:19<00:00,  4.54it/s, loss=1.3578, lr=0.000429]


Train Loss: 1.3997


Epoch 176 [Val]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 11/11 [00:02<00:00,  5.09it/s, loss=0.5582]


Val Loss: 1.3490
Pos Sim: 0.8490 | Neg Sim: -0.0009 | Separation: 0.8499
Learning Rate: 0.000429
Checkpoint saved: checkpoints/contraeend_epoch_176.pth
Patience: 8/15

Epoch 177/200


Epoch 177 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 90/90 [00:19<00:00,  4.51it/s, loss=1.3686, lr=0.000429]


Train Loss: 1.3734


Epoch 177 [Val]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 11/11 [00:02<00:00,  5.09it/s, loss=0.4290]


Val Loss: 1.4249
Pos Sim: 0.8303 | Neg Sim: 0.0002 | Separation: 0.8301
Learning Rate: 0.000429
Checkpoint saved: checkpoints/contraeend_epoch_177.pth
Patience: 9/15

Epoch 178/200


Epoch 178 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 90/90 [00:19<00:00,  4.50it/s, loss=1.4045, lr=0.000429]


Train Loss: 1.3768


Epoch 178 [Val]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 11/11 [00:02<00:00,  5.10it/s, loss=0.1115]


Val Loss: 1.3048
Pos Sim: 0.8601 | Neg Sim: -0.0006 | Separation: 0.8607
Learning Rate: 0.000429
Checkpoint saved: checkpoints/contraeend_epoch_178.pth
Checkpoint saved: checkpoints/contraeend_best.pth
‚úì New best model saved! Val Loss: 1.3048

Epoch 179/200


Epoch 179 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 90/90 [00:20<00:00,  4.46it/s, loss=1.3866, lr=0.000428]


Train Loss: 1.3698


Epoch 179 [Val]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 11/11 [00:02<00:00,  4.84it/s, loss=0.2852]


Val Loss: 1.3284
Pos Sim: 0.8590 | Neg Sim: -0.0019 | Separation: 0.8609
Learning Rate: 0.000428
Checkpoint saved: checkpoints/contraeend_epoch_179.pth
Patience: 1/15

Epoch 180/200


Epoch 180 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 90/90 [00:20<00:00,  4.47it/s, loss=1.3509, lr=0.000428]


Train Loss: 1.3699


Epoch 180 [Val]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 11/11 [00:02<00:00,  5.08it/s, loss=0.2063]


Val Loss: 1.4062
Pos Sim: 0.8402 | Neg Sim: -0.0000 | Separation: 0.8402
Learning Rate: 0.000428
Checkpoint saved: checkpoints/contraeend_epoch_180.pth
Patience: 2/15

Epoch 181/200


Epoch 181 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 90/90 [00:19<00:00,  4.58it/s, loss=1.3562, lr=0.000428]


Train Loss: 1.3786


Epoch 181 [Val]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 11/11 [00:02<00:00,  5.07it/s, loss=0.0816]


Val Loss: 1.3802
Pos Sim: 0.8373 | Neg Sim: 0.0022 | Separation: 0.8351
Learning Rate: 0.000428
Checkpoint saved: checkpoints/contraeend_epoch_181.pth
Patience: 3/15

Epoch 182/200


Epoch 182 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 90/90 [00:20<00:00,  4.49it/s, loss=1.3618, lr=0.000427]


Train Loss: 1.3561


Epoch 182 [Val]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 11/11 [00:02<00:00,  5.24it/s, loss=0.1379]


Val Loss: 1.3665
Pos Sim: 0.8275 | Neg Sim: -0.0004 | Separation: 0.8279
Learning Rate: 0.000427
Checkpoint saved: checkpoints/contraeend_epoch_182.pth
Patience: 4/15

Epoch 183/200


Epoch 183 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 90/90 [00:19<00:00,  4.55it/s, loss=1.3669, lr=0.000427]


Train Loss: 1.3559


Epoch 183 [Val]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 11/11 [00:02<00:00,  5.15it/s, loss=0.2408]


Val Loss: 1.2953
Pos Sim: 0.8584 | Neg Sim: -0.0013 | Separation: 0.8598
Learning Rate: 0.000427
Checkpoint saved: checkpoints/contraeend_epoch_183.pth
Checkpoint saved: checkpoints/contraeend_best.pth
‚úì New best model saved! Val Loss: 1.2953

Epoch 184/200


Epoch 184 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 90/90 [00:20<00:00,  4.48it/s, loss=1.3742, lr=0.000427]


Train Loss: 1.3516


Epoch 184 [Val]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 11/11 [00:02<00:00,  5.19it/s, loss=0.6997]


Val Loss: 1.3703
Pos Sim: 0.8512 | Neg Sim: -0.0006 | Separation: 0.8518
Learning Rate: 0.000427
Checkpoint saved: checkpoints/contraeend_epoch_184.pth
Patience: 1/15

Epoch 185/200


Epoch 185 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 90/90 [00:19<00:00,  4.51it/s, loss=1.3392, lr=0.000426]


Train Loss: 1.3537


Epoch 185 [Val]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 11/11 [00:02<00:00,  5.01it/s, loss=0.5061]


Val Loss: 1.3936
Pos Sim: 0.8370 | Neg Sim: -0.0009 | Separation: 0.8379
Learning Rate: 0.000426
Checkpoint saved: checkpoints/contraeend_epoch_185.pth
Patience: 2/15

Epoch 186/200


Epoch 186 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 90/90 [00:19<00:00,  4.51it/s, loss=1.3502, lr=0.000426]


Train Loss: 1.3474


Epoch 186 [Val]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 11/11 [00:02<00:00,  5.20it/s, loss=0.1371]


Val Loss: 1.3119
Pos Sim: 0.8433 | Neg Sim: -0.0006 | Separation: 0.8439
Learning Rate: 0.000426
Checkpoint saved: checkpoints/contraeend_epoch_186.pth
Patience: 3/15

Epoch 187/200


Epoch 187 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 90/90 [00:19<00:00,  4.54it/s, loss=1.4014, lr=0.000425]


Train Loss: 1.3368


Epoch 187 [Val]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 11/11 [00:02<00:00,  5.14it/s, loss=0.0940]


Val Loss: 1.3310
Pos Sim: 0.8419 | Neg Sim: -0.0004 | Separation: 0.8424
Learning Rate: 0.000425
Checkpoint saved: checkpoints/contraeend_epoch_187.pth
Patience: 4/15

Epoch 188/200


Epoch 188 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 90/90 [00:19<00:00,  4.53it/s, loss=1.3763, lr=0.000425]


Train Loss: 1.3385


Epoch 188 [Val]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 11/11 [00:02<00:00,  5.06it/s, loss=0.5713]


Val Loss: 1.3504
Pos Sim: 0.8512 | Neg Sim: -0.0003 | Separation: 0.8515
Learning Rate: 0.000425
Checkpoint saved: checkpoints/contraeend_epoch_188.pth
Patience: 5/15

Epoch 189/200


Epoch 189 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 90/90 [00:19<00:00,  4.59it/s, loss=1.3090, lr=0.000425]


Train Loss: 1.3272


Epoch 189 [Val]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 11/11 [00:02<00:00,  5.18it/s, loss=0.1840]


Val Loss: 1.3475
Pos Sim: 0.8590 | Neg Sim: -0.0001 | Separation: 0.8590
Learning Rate: 0.000425
Checkpoint saved: checkpoints/contraeend_epoch_189.pth
Patience: 6/15

Epoch 190/200


Epoch 190 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 90/90 [00:19<00:00,  4.57it/s, loss=1.4078, lr=0.000424]


Train Loss: 1.3363


Epoch 190 [Val]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 11/11 [00:02<00:00,  5.09it/s, loss=0.1210]


Val Loss: 1.3371
Pos Sim: 0.8507 | Neg Sim: -0.0011 | Separation: 0.8517
Learning Rate: 0.000424
Checkpoint saved: checkpoints/contraeend_epoch_190.pth
Patience: 7/15

Epoch 191/200


Epoch 191 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 90/90 [00:19<00:00,  4.62it/s, loss=1.3106, lr=0.000424]


Train Loss: 1.3565


Epoch 191 [Val]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 11/11 [00:02<00:00,  5.18it/s, loss=0.1702]


Val Loss: 1.3504
Pos Sim: 0.8451 | Neg Sim: 0.0001 | Separation: 0.8451
Learning Rate: 0.000424
Checkpoint saved: checkpoints/contraeend_epoch_191.pth
Patience: 8/15

Epoch 192/200


Epoch 192 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 90/90 [00:20<00:00,  4.40it/s, loss=1.2824, lr=0.000424]


Train Loss: 1.3346


Epoch 192 [Val]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 11/11 [00:02<00:00,  5.10it/s, loss=0.2650]


Val Loss: 1.3031
Pos Sim: 0.8511 | Neg Sim: -0.0005 | Separation: 0.8516
Learning Rate: 0.000424
Checkpoint saved: checkpoints/contraeend_epoch_192.pth
Patience: 9/15

Epoch 193/200


Epoch 193 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 90/90 [00:19<00:00,  4.54it/s, loss=1.3726, lr=0.000423]


Train Loss: 1.3194


Epoch 193 [Val]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 11/11 [00:02<00:00,  5.02it/s, loss=0.2372]


Val Loss: 1.3311
Pos Sim: 0.8364 | Neg Sim: -0.0008 | Separation: 0.8372
Learning Rate: 0.000423
Checkpoint saved: checkpoints/contraeend_epoch_193.pth
Patience: 10/15

Epoch 194/200


Epoch 194 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 90/90 [00:19<00:00,  4.51it/s, loss=1.3523, lr=0.000423]


Train Loss: 1.3229


Epoch 194 [Val]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 11/11 [00:02<00:00,  5.00it/s, loss=0.5397]


Val Loss: 1.4075
Pos Sim: 0.8553 | Neg Sim: -0.0002 | Separation: 0.8555
Learning Rate: 0.000423
Checkpoint saved: checkpoints/contraeend_epoch_194.pth
Patience: 11/15

Epoch 195/200


Epoch 195 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 90/90 [00:20<00:00,  4.50it/s, loss=1.3337, lr=0.000423]


Train Loss: 1.3112


Epoch 195 [Val]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 11/11 [00:02<00:00,  4.77it/s, loss=0.1800]


Val Loss: 1.2771
Pos Sim: 0.8642 | Neg Sim: -0.0019 | Separation: 0.8661
Learning Rate: 0.000423
Checkpoint saved: checkpoints/contraeend_epoch_195.pth
Checkpoint saved: checkpoints/contraeend_best.pth
‚úì New best model saved! Val Loss: 1.2771

Epoch 196/200


Epoch 196 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 90/90 [00:20<00:00,  4.45it/s, loss=1.3333, lr=0.000422]


Train Loss: 1.3196


Epoch 196 [Val]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 11/11 [00:02<00:00,  4.82it/s, loss=0.4547]


Val Loss: 1.3035
Pos Sim: 0.8599 | Neg Sim: -0.0013 | Separation: 0.8612
Learning Rate: 0.000422
Checkpoint saved: checkpoints/contraeend_epoch_196.pth
Patience: 1/15

Epoch 197/200


Epoch 197 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 90/90 [00:19<00:00,  4.63it/s, loss=1.3576, lr=0.000422]


Train Loss: 1.3251


Epoch 197 [Val]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 11/11 [00:02<00:00,  5.16it/s, loss=0.3749]


Val Loss: 1.3823
Pos Sim: 0.8359 | Neg Sim: 0.0007 | Separation: 0.8353
Learning Rate: 0.000422
Checkpoint saved: checkpoints/contraeend_epoch_197.pth
Patience: 2/15

Epoch 198/200


Epoch 198 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 90/90 [00:19<00:00,  4.59it/s, loss=1.3209, lr=0.000422]


Train Loss: 1.3189


Epoch 198 [Val]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 11/11 [00:02<00:00,  5.17it/s, loss=0.2518]


Val Loss: 1.2632
Pos Sim: 0.8615 | Neg Sim: -0.0012 | Separation: 0.8626
Learning Rate: 0.000422
Checkpoint saved: checkpoints/contraeend_epoch_198.pth
Checkpoint saved: checkpoints/contraeend_best.pth
‚úì New best model saved! Val Loss: 1.2632

Epoch 199/200


Epoch 199 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 90/90 [00:19<00:00,  4.57it/s, loss=1.2914, lr=0.000421]


Train Loss: 1.3109


Epoch 199 [Val]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 11/11 [00:02<00:00,  5.22it/s, loss=0.4781]


Val Loss: 1.3712
Pos Sim: 0.8596 | Neg Sim: -0.0005 | Separation: 0.8600
Learning Rate: 0.000421
Checkpoint saved: checkpoints/contraeend_epoch_199.pth
Patience: 1/15

Epoch 200/200


Epoch 200 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 90/90 [00:19<00:00,  4.58it/s, loss=1.3062, lr=0.000421]


Train Loss: 1.3203


Epoch 200 [Val]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 11/11 [00:02<00:00,  5.18it/s, loss=0.4280]


Val Loss: 1.2726
Pos Sim: 0.8543 | Neg Sim: -0.0016 | Separation: 0.8559
Learning Rate: 0.000421
Checkpoint saved: checkpoints/contraeend_epoch_200.pth
Patience: 2/15

Training completed! Best validation loss: 1.2632

== Training Complete! ==
Best Validation Loss: 1.2632


In [11]:
!ls /kaggle/working/checkpoints

contraeend_best.pth	  contraeend_epoch_139.pth  contraeend_epoch_179.pth
contraeend_epoch_100.pth  contraeend_epoch_140.pth  contraeend_epoch_180.pth
contraeend_epoch_101.pth  contraeend_epoch_141.pth  contraeend_epoch_181.pth
contraeend_epoch_102.pth  contraeend_epoch_142.pth  contraeend_epoch_182.pth
contraeend_epoch_103.pth  contraeend_epoch_143.pth  contraeend_epoch_183.pth
contraeend_epoch_104.pth  contraeend_epoch_144.pth  contraeend_epoch_184.pth
contraeend_epoch_105.pth  contraeend_epoch_145.pth  contraeend_epoch_185.pth
contraeend_epoch_106.pth  contraeend_epoch_146.pth  contraeend_epoch_186.pth
contraeend_epoch_107.pth  contraeend_epoch_147.pth  contraeend_epoch_187.pth
contraeend_epoch_108.pth  contraeend_epoch_148.pth  contraeend_epoch_188.pth
contraeend_epoch_109.pth  contraeend_epoch_149.pth  contraeend_epoch_189.pth
contraeend_epoch_110.pth  contraeend_epoch_150.pth  contraeend_epoch_190.pth
contraeend_epoch_111.pth  contraeend_epoch_151.pth  contraeend_epoch_191.pth
con

In [12]:
!cp /kaggle/working/checkpoints/contraeend_best.pth /kaggle/working/

In [13]:
# !zip -r checkpoint.zip /kaggle/working/checkpoints/contraeend_best.pth