In [None]:
import os
import sys
import pandas as pd
import numpy as np
import cv2
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_absolute_error, accuracy_score, roc_auc_score, classification_report, confusion_matrix
import warnings
warnings.filterwarnings('ignore')

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
import torchvision.models as models

import albumentations as A
from albumentations.pytorch import ToTensorV2

import time
from tqdm.auto import tqdm
import gc

# =====================================================================
# CONFIGURATION
# =====================================================================

class Config:
    """Central bundle of paths and hyper-parameters for the training run.

    Every setting is a class attribute, so values can be read from the class
    itself or from an instance.
    """

    # Paths
    TRAIN_CSV = '/kaggle/input/hand-xrays-processed-excel-file/processed_boneage_dataset.csv'
    TRAIN_DIR = '/kaggle/input/hand-xrays/boneage-training-dataset'
    CHECKPOINT_DIR = '/kaggle/working'

    # Model
    MODEL_NAME = 'efficientnet_b3'
    IMAGE_SIZE = 512   # images are resized to IMAGE_SIZE x IMAGE_SIZE
    BATCH_SIZE = 32
    EPOCHS = 45

    # Device
    # Fall back to CPU instead of failing on hosts without a CUDA device.
    DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
    NUM_WORKERS = 2
    PIN_MEMORY = True

    # Training
    BASE_LR = 5e-4
    WEIGHT_DECAY = 1e-4
    GRADIENT_CLIP = 1.0   # max gradient norm used for clipping
    DROPOUT = 0.4

    # Task weights (multipliers of the age and gender losses in the total loss)
    AGE_WEIGHT = 0.5
    GENDER_WEIGHT = 0.5

    # Optimization
    USE_AMP = True
    ACCUMULATION_STEPS = 1

    # Early stopping
    PATIENCE = 15
    MIN_DELTA = 0.001

    SEED = 42

# Shared settings object read by the rest of the script.
config = Config()
# Checkpoints, the CSV log and the history plot are written here; create it up front.
os.makedirs(config.CHECKPOINT_DIR, exist_ok=True)

def set_seed(seed):
    """Seed the random number generators used by this script.

    Seeds Python's ``random``, NumPy and PyTorch (CPU and all CUDA devices).
    Python's ``random`` was previously left unseeded, so any library drawing
    from it (augmentation libraries may, depending on version) was not
    reproducible between runs.

    cuDNN is deliberately left non-deterministic with autotuning enabled, so
    runs on GPU are not guaranteed to be bit-for-bit reproducible.

    Args:
        seed: Integer seed applied to every generator.
    """
    import random  # local import: not imported at file level

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = False
    torch.backends.cudnn.benchmark = True

set_seed(config.SEED)

# Report the accelerator, but do not crash at import time on CPU-only hosts:
# the torch.cuda.get_device_* calls raise when no CUDA device is present.
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory/1e9:.1f}GB")
else:
    print("GPU: not available (CUDA not detected)")

# =====================================================================
# CALLBACK SYSTEM
# =====================================================================

class Callback:
    """No-op base class that defines the training hooks.

    Subclasses override only the hooks they care about; each hook here
    accepts arbitrary keyword arguments, does nothing and returns ``None``.
    """

    def on_train_start(self, **kwargs):
        """Hook called before training begins."""
        return None

    def on_train_end(self, **kwargs):
        """Hook called after training has finished."""
        return None

    def on_epoch_start(self, epoch, **kwargs):
        """Hook called at the beginning of epoch ``epoch``."""
        return None

    def on_epoch_end(self, epoch, logs=None, **kwargs):
        """Hook called at the end of epoch ``epoch`` with its metrics in ``logs``."""
        return None

class ModelCheckpoint(Callback):
    """Save the model weights whenever the monitored metric hits a new best.

    The weights are written to ``<checkpoint_dir>/best_<monitor>.pth`` and
    overwritten on every improvement.

    Args:
        checkpoint_dir: Directory the weight file is written to.
        monitor: Key looked up in the per-epoch ``logs`` dict.
        mode: ``'max'`` if larger is better, ``'min'`` if smaller is better.
        verbose: Print a line each time a new best is saved.

    Raises:
        ValueError: If ``mode`` is neither ``'max'`` nor ``'min'``.
    """

    def __init__(self, checkpoint_dir, monitor='gender_acc', mode='max', verbose=True):
        # An unknown mode would otherwise never register an improvement and
        # silently never save anything.
        if mode not in ('max', 'min'):
            raise ValueError(f"mode must be 'max' or 'min', got {mode!r}")
        self.checkpoint_dir = checkpoint_dir
        self.monitor = monitor
        self.mode = mode
        self.verbose = verbose
        self.best_value = float('-inf') if mode == 'max' else float('inf')

    def on_epoch_end(self, epoch, logs=None, model=None, **kwargs):
        """Save ``model``'s state dict if ``logs[monitor]`` beats the best so far."""
        if logs is None or model is None:
            return

        current = logs.get(self.monitor)
        if current is None:
            return

        if self.mode == 'max':
            improved = current > self.best_value
        else:
            improved = current < self.best_value
        if not improved:
            return

        self.best_value = current
        path = os.path.join(self.checkpoint_dir, f'best_{self.monitor}.pth')
        torch.save(model.state_dict(), path)
        if self.verbose:
            print(f"  ✓ Best {self.monitor}: {current:.4f} (saved)")

class EarlyStopping(Callback):
    """Flag that training should stop once a metric has stopped improving.

    The callback does not interrupt training itself; it sets ``should_stop``
    to ``True`` and the training loop is expected to check that flag.

    Args:
        monitor: Key looked up in the per-epoch ``logs`` dict.
        patience: Number of consecutive non-improving epochs tolerated.
        mode: ``'max'`` if larger is better, ``'min'`` if smaller is better.
        min_delta: Minimum change over the best value that counts as an
            improvement.
        verbose: Print a message when stopping is triggered.

    Raises:
        ValueError: If ``mode`` is neither ``'max'`` nor ``'min'``.
    """

    def __init__(self, monitor='gender_acc', patience=15, mode='max', min_delta=0.001, verbose=True):
        # An unknown mode used to be treated as 'min' for the comparison while
        # the best value was initialised for 'min' too -- reject it explicitly.
        if mode not in ('max', 'min'):
            raise ValueError(f"mode must be 'max' or 'min', got {mode!r}")
        self.monitor = monitor
        self.patience = patience
        self.mode = mode
        self.min_delta = min_delta
        self.verbose = verbose
        self.best_value = float('-inf') if mode == 'max' else float('inf')
        self.counter = 0
        self.should_stop = False

    def on_epoch_end(self, epoch, logs=None, **kwargs):
        """Update the patience counter from ``logs[monitor]``."""
        if logs is None:
            return

        current = logs.get(self.monitor)
        if current is None:
            return

        if self.mode == 'max':
            improved = current > self.best_value + self.min_delta
        else:
            improved = current < self.best_value - self.min_delta

        if improved:
            self.best_value = current
            self.counter = 0
            return

        self.counter += 1
        if self.counter >= self.patience:
            self.should_stop = True
            if self.verbose:
                print(f"\n⚠ Early stopping triggered at epoch {epoch}")

class CSVLogger(Callback):
    """Accumulate per-epoch metrics in memory and mirror them to a CSV file."""

    def __init__(self, filename):
        self.logs = []
        self.filename = filename

    def on_epoch_end(self, epoch, logs=None, **kwargs):
        """Append this epoch's metrics and rewrite the CSV with the full history."""
        if not logs:
            return

        record = {'epoch': epoch}
        record.update(logs)
        self.logs.append(record)

        # The whole file is regenerated every epoch from the in-memory history.
        pd.DataFrame(self.logs).to_csv(self.filename, index=False)

class PeriodicCheckpoint(Callback):
    """Save a full training checkpoint every ``period`` epochs.

    The checkpoint is a dict with the epoch number, the model state, the
    optimizer and scheduler states (``None`` when not supplied) and the
    epoch's ``logs``; it is written to
    ``<checkpoint_dir>/checkpoint_epoch_<epoch>.pth``.

    Args:
        checkpoint_dir: Directory the checkpoint files are written to.
        period: Save whenever ``epoch`` is a multiple of this value.

    Raises:
        ValueError: If ``period`` is smaller than 1.
    """

    def __init__(self, checkpoint_dir, period=5):
        # period=0 would otherwise only fail with ZeroDivisionError at the end
        # of the first epoch, after a full epoch of training.
        if period < 1:
            raise ValueError(f"period must be a positive integer, got {period!r}")
        self.checkpoint_dir = checkpoint_dir
        self.period = period

    def on_epoch_end(self, epoch, model=None, optimizer=None, scheduler=None, logs=None, **kwargs):
        """Write a checkpoint if ``epoch`` is a multiple of ``period``."""
        if model is None or epoch % self.period != 0:
            return

        checkpoint = {
            'epoch': epoch,
            'model': model.state_dict(),
            'optimizer': optimizer.state_dict() if optimizer else None,
            'scheduler': scheduler.state_dict() if scheduler else None,
            'logs': logs,
        }
        path = os.path.join(self.checkpoint_dir, f'checkpoint_epoch_{epoch}.pth')
        torch.save(checkpoint, path)
        print(f"  💾 Checkpoint saved (epoch {epoch})")

class LRSchedulerCallback(Callback):
    """Advance a learning-rate scheduler by one step at the end of each epoch."""

    def __init__(self, scheduler):
        self.scheduler = scheduler

    def on_epoch_end(self, epoch, **kwargs):
        """Call ``scheduler.step()``; does nothing when no scheduler was given."""
        if not self.scheduler:
            return
        self.scheduler.step()

class MetricsDisplay(Callback):
    """Print the epoch's metrics in a fixed, human-readable layout."""

    def on_epoch_end(self, epoch, logs=None, **kwargs):
        """Print losses, age MAE, gender accuracy and gender AUC from ``logs``.

        Metrics missing from ``logs`` are shown as 0; nothing is printed when
        ``logs`` is empty or ``None``.
        """
        if not logs:
            return

        gender_acc = logs.get('gender_acc', 0)
        print("\n📊 Results:")
        print(f"  Train Loss: {logs.get('train_loss', 0):.4f} | Val Loss: {logs.get('val_loss', 0):.4f}")
        print(f"  Age MAE: {logs.get('age_mae', 0):.4f}")
        print(f"  Gender Acc: {gender_acc:.4f} ({gender_acc*100:.2f}%)")
        print(f"  Gender AUC: {logs.get('gender_auc', 0):.4f}")

# =====================================================================
# DATASET
# =====================================================================

class BoneAgeDataset(Dataset):
    """Dataset yielding ``(image, age, gender)`` for each row of a dataframe.

    Each row must provide the columns ``id`` (image file stem),
    ``BoneAgeNorm`` (regression target) and ``Gender`` (class index).

    Args:
        df: Dataframe with one row per sample; its index is reset.
        image_dir: Directory containing ``<id>.png`` files.
        transform: Optional callable invoked as ``transform(image=img)`` whose
            result's ``'image'`` entry becomes the returned image.
    """

    def __init__(self, df, image_dir, transform=None):
        self.df = df.reset_index(drop=True)
        self.image_dir = image_dir
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Load, contrast-enhance and (optionally) transform sample ``idx``.

        Returns:
            Tuple ``(image, age, gender)`` where ``age`` is a float32 scalar
            tensor and ``gender`` an int64 scalar tensor. Without a transform
            the image is the RGB uint8 array produced by OpenCV.
        """
        row = self.df.iloc[idx]
        img_path = os.path.join(self.image_dir, f"{row['id']}.png")
        image = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)

        if image is None:
            # Keep the best-effort fallback (a black frame) so one bad file does
            # not abort training, but make it visible: the sample still carries
            # real labels. stderr is used because warnings are silenced globally.
            print(f"Warning: could not read image '{img_path}'; using a blank image",
                  file=sys.stderr)
            image = np.zeros((config.IMAGE_SIZE, config.IMAGE_SIZE), dtype=np.uint8)

        # Local contrast enhancement on the grayscale image, then expand to
        # three channels for the downstream transforms.
        clahe = cv2.createCLAHE(clipLimit=2.5, tileGridSize=(8, 8))
        image = clahe.apply(image)
        image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)

        if self.transform:
            image = self.transform(image=image)['image']

        age = torch.tensor(row['BoneAgeNorm'], dtype=torch.float32)
        gender = torch.tensor(row['Gender'], dtype=torch.long)

        return image, age, gender

# =====================================================================
# TRANSFORMS
# =====================================================================

def get_train_transforms():
    """Build the augmentation pipeline applied to training images.

    The pipeline resizes to ``config.IMAGE_SIZE``, applies random geometric,
    noise/blur, distortion, brightness/contrast, CLAHE and dropout
    augmentations, then normalizes and converts to a tensor.
    """
    side = config.IMAGE_SIZE

    # Exactly one of these is applied when the group fires (p=0.3).
    noise_or_blur = A.OneOf([
        A.GaussNoise(var_limit=(10, 50), p=1),
        A.GaussianBlur(blur_limit=(3, 5), p=1),
        A.MotionBlur(blur_limit=5, p=1),
    ], p=0.3)

    # Exactly one of these is applied when the group fires (p=0.3).
    warp = A.OneOf([
        A.OpticalDistortion(distort_limit=0.15, p=1),
        A.GridDistortion(num_steps=5, distort_limit=0.15, p=1),
        A.ElasticTransform(alpha=1, sigma=50, p=1),
    ], p=0.3)

    steps = [
        A.Resize(side, side),
        A.HorizontalFlip(p=0.5),
        A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.2, rotate_limit=20, p=0.7),
        noise_or_blur,
        warp,
        A.RandomBrightnessContrast(brightness_limit=0.25, contrast_limit=0.25, p=0.6),
        A.CLAHE(clip_limit=4.0, p=0.4),
        A.CoarseDropout(max_holes=8, max_height=32, max_width=32, p=0.4),
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ToTensorV2(),
    ]
    return A.Compose(steps)

def get_val_transforms():
    """Build the deterministic pipeline for validation images.

    Only resizing to ``config.IMAGE_SIZE``, normalization and tensor
    conversion are applied; there is no random augmentation.
    """
    side = config.IMAGE_SIZE
    steps = [
        A.Resize(side, side),
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ToTensorV2(),
    ]
    return A.Compose(steps)

# =====================================================================
# MODEL
# =====================================================================

class BoneAgeGenderModel(nn.Module):
    """Two-headed network on top of an EfficientNet-B3 backbone.

    A shared trunk feeds an age head (single regression output) and a gender
    head (two class logits). The layers are created in a fixed order; keep
    that order, since it determines both the random initialisation under a
    fixed seed and the parameter names stored in saved checkpoints.
    """
    def __init__(self):
        super().__init__()
        
        # Pretrained backbone; its classifier is swapped for Identity so the
        # backbone emits the feature vector that used to feed the classifier
        # (width `in_features`).
        self.backbone = models.efficientnet_b3(pretrained=True)
        in_features = self.backbone.classifier[1].in_features
        self.backbone.classifier = nn.Identity()
        
        # Bottleneck MLP ending in a sigmoid: produces one gate in (0, 1) per
        # feature, multiplied element-wise with the features in forward().
        self.attention = nn.Sequential(
            nn.Linear(in_features, in_features // 8),
            nn.ReLU(inplace=True),
            nn.Linear(in_features // 8, in_features),
            nn.Sigmoid()
        )
        
        # Trunk shared by both tasks: in_features -> 1024 -> 512.
        self.shared = nn.Sequential(
            nn.Linear(in_features, 1024),
            nn.BatchNorm1d(1024),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),
            nn.Linear(1024, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.2)
        )
        
        # Regression head: 512 -> 256 -> 128 -> 1.
        self.age_head = nn.Sequential(
            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(inplace=True),
            nn.Dropout(config.DROPOUT),
            nn.Linear(256, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 1)
        )
        
        # Classification head: 512 -> 512 -> 256 -> 2 logits; the second
        # dropout uses half of config.DROPOUT.
        self.gender_head = nn.Sequential(
            nn.Linear(512, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Dropout(config.DROPOUT),
            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(inplace=True),
            nn.Dropout(config.DROPOUT * 0.5),
            nn.Linear(256, 2)
        )
        
    def forward(self, x):
        """Return ``(age, gender)`` predictions for a batch.

        Args:
            x: Input batch passed to the backbone (assumed to be a batch of
                3-channel images -- TODO confirm against the data pipeline).

        Returns:
            ``age``: the age head output with its trailing size-1 dimension
            squeezed away. ``gender``: the raw two-class scores from the
            gender head (no softmax is applied here).
        """
        features = self.backbone(x)
        # Gate the backbone features element-wise before the shared trunk.
        att_weights = self.attention(features)
        features = features * att_weights
        shared = self.shared(features)
        age = self.age_head(shared).squeeze(-1)
        gender = self.gender_head(shared)
        return age, gender

# =====================================================================
# TRAINING
# =====================================================================

def train_epoch(model, loader, optimizer, age_criterion, gender_criterion, scaler, epoch):
    """Train ``model`` for one pass over ``loader``.

    The total loss is ``AGE_WEIGHT * age_loss + GENDER_WEIGHT * gender_loss``.
    Gradients are accumulated over ``config.ACCUMULATION_STEPS`` batches; the
    backward pass uses the loss divided by that count so the accumulated
    gradient is an average, and any partial window at the end of the epoch is
    flushed with an optimizer step rather than leaking into the next epoch.
    With ``ACCUMULATION_STEPS == 1`` this is identical to stepping every batch.

    Args:
        model: Network returning ``(age_pred, gender_pred)``.
        loader: Iterable of ``(images, age_target, gender_target)`` batches
            that supports ``len()``.
        optimizer: Optimizer updating ``model``'s parameters.
        age_criterion: Loss applied to the age prediction.
        gender_criterion: Loss applied to the gender logits.
        scaler: AMP gradient scaler.
        epoch: Epoch label shown in the progress bar.

    Returns:
        Tuple ``(mean_total_loss, mean_age_loss, mean_gender_loss)`` over the
        batches of this epoch (un-normalised per-batch losses).
    """
    model.train()
    losses, age_losses, gender_losses = [], [], []
    correct, total = 0, 0

    accum_steps = max(1, config.ACCUMULATION_STEPS)
    num_batches = len(loader)

    # Start every epoch from clean gradients.
    optimizer.zero_grad()

    pbar = tqdm(loader, desc=f'Train Epoch {epoch}')

    for batch_idx, (images, age_target, gender_target) in enumerate(pbar):
        images = images.to(config.DEVICE, non_blocking=True)
        age_target = age_target.to(config.DEVICE, non_blocking=True)
        gender_target = gender_target.to(config.DEVICE, non_blocking=True)

        with torch.cuda.amp.autocast(enabled=config.USE_AMP):
            age_pred, gender_pred = model(images)
            age_loss = age_criterion(age_pred, age_target)
            gender_loss = gender_criterion(gender_pred, gender_target)
            loss = config.AGE_WEIGHT * age_loss + config.GENDER_WEIGHT * gender_loss

        # Normalise so that accumulated gradients average over the window
        # instead of growing with ACCUMULATION_STEPS.
        scaler.scale(loss / accum_steps).backward()

        window_complete = (batch_idx + 1) % accum_steps == 0
        last_batch = (batch_idx + 1) == num_batches
        if window_complete or last_batch:
            # Gradients must be unscaled before clipping so the threshold
            # applies to their true magnitude.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), config.GRADIENT_CLIP)
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()

        losses.append(loss.item())
        age_losses.append(age_loss.item())
        gender_losses.append(gender_loss.item())

        preds = gender_pred.argmax(dim=1)
        correct += (preds == gender_target).sum().item()
        total += gender_target.size(0)

        if batch_idx % 50 == 0:
            pbar.set_postfix({'loss': f'{np.mean(losses[-50:]):.4f}', 'acc': f'{correct/total:.3f}'})

    return np.mean(losses), np.mean(age_losses), np.mean(gender_losses)

def validate_epoch(model, loader, age_criterion, gender_criterion, epoch):
    """Evaluate ``model`` on ``loader`` without updating it.

    Args:
        model: Network returning ``(age_pred, gender_pred)``.
        loader: Iterable of ``(images, age_target, gender_target)`` batches.
        age_criterion: Loss applied to the age prediction.
        gender_criterion: Loss applied to the gender logits.
        epoch: Epoch label shown in the progress bar.

    Returns:
        Tuple ``(mean_loss, age_mae, gender_acc, gender_auc)``. ``age_mae`` is
        in the same units as the age targets. ``gender_auc`` is ``nan`` when
        the targets contain a single class, since AUC is undefined then.
    """
    model.eval()
    losses = []
    age_preds_list, age_targets_list = [], []
    gender_preds_list, gender_targets_list, gender_probs_list = [], [], []

    pbar = tqdm(loader, desc=f'Val Epoch {epoch}')

    with torch.no_grad():
        for images, age_target, gender_target in pbar:
            images = images.to(config.DEVICE, non_blocking=True)
            age_target = age_target.to(config.DEVICE, non_blocking=True)
            gender_target = gender_target.to(config.DEVICE, non_blocking=True)

            with torch.cuda.amp.autocast(enabled=config.USE_AMP):
                age_pred, gender_pred = model(images)
                age_loss = age_criterion(age_pred, age_target)
                gender_loss = gender_criterion(gender_pred, gender_target)
                loss = config.AGE_WEIGHT * age_loss + config.GENDER_WEIGHT * gender_loss

            losses.append(loss.item())

            # Under autocast the model outputs can be half precision; promote
            # to float32 before computing metrics so MAE is not quantised and
            # AUC does not suffer artificial ties between probabilities.
            age_preds_list.append(age_pred.float().cpu().numpy())
            age_targets_list.append(age_target.cpu().numpy())

            gender_prob = torch.softmax(gender_pred.float(), dim=1)
            gender_preds_list.append(gender_prob.argmax(dim=1).cpu().numpy())
            gender_targets_list.append(gender_target.cpu().numpy())
            gender_probs_list.append(gender_prob[:, 1].cpu().numpy())

    age_preds = np.concatenate(age_preds_list)
    age_targets = np.concatenate(age_targets_list)
    gender_preds = np.concatenate(gender_preds_list)
    gender_targets = np.concatenate(gender_targets_list)
    gender_probs = np.concatenate(gender_probs_list)

    age_mae = mean_absolute_error(age_targets, age_preds)
    gender_acc = accuracy_score(gender_targets, gender_preds)
    try:
        gender_auc = roc_auc_score(gender_targets, gender_probs)
    except ValueError:
        # Raised when only one class is present in the targets.
        gender_auc = float('nan')

    return np.mean(losses), age_mae, gender_acc, gender_auc

# =====================================================================
# MAIN TRAINING WITH CALLBACKS
# =====================================================================

def main():
    """Run the full training pipeline and return the trained model.

    Steps: load the CSV, make a gender-stratified train/validation split,
    build class-balanced loaders, train with warmup + cosine learning-rate
    scheduling under the callback system, then reload the weights with the
    best validation gender accuracy and report final metrics.

    Returns:
        Tuple ``(model, history)`` where ``model`` carries the
        best-gender-accuracy weights and ``history`` maps each metric name to
        its list of per-epoch values.
    """
    banner = "=" * 70

    print("\n" + banner)
    # Derive the banner from the configuration so it cannot go stale.
    print(f"PRODUCTION TRAINING WITH CALLBACKS - {config.EPOCHS} EPOCHS")
    print(banner)
    print(f"Model: {config.MODEL_NAME}")
    print(f"Batch Size: {config.BATCH_SIZE}")
    print(f"Epochs: {config.EPOCHS}")
    print(banner)

    # Load data
    print("\nLoading data...")
    df = pd.read_csv(config.TRAIN_CSV)
    print(f"Total: {len(df)}, Age: {df['boneage'].min():.0f}-{df['boneage'].max():.0f} months")
    print(f"Gender: Female={len(df[df.Gender==0])}, Male={len(df[df.Gender==1])}")

    train_df, val_df = train_test_split(df, test_size=0.15, stratify=df['Gender'], random_state=config.SEED)
    print(f"Train: {len(train_df)}, Val: {len(val_df)}")

    # Datasets
    train_dataset = BoneAgeDataset(train_df, config.TRAIN_DIR, get_train_transforms())
    val_dataset = BoneAgeDataset(val_df, config.TRAIN_DIR, get_val_transforms())

    # Balanced sampling: each sample is weighted by the inverse frequency of
    # its gender so both classes are drawn equally often in expectation.
    gender_counts = train_df['Gender'].value_counts()
    weights = 1.0 / gender_counts
    sample_weights = train_df['Gender'].map(weights).values
    sampler = WeightedRandomSampler(sample_weights, len(sample_weights))

    # Loaders. persistent_workers / prefetch_factor are only valid with worker
    # processes, so they are added only when NUM_WORKERS > 0.
    loader_kwargs = dict(num_workers=config.NUM_WORKERS, pin_memory=config.PIN_MEMORY)
    if config.NUM_WORKERS > 0:
        loader_kwargs.update(persistent_workers=True, prefetch_factor=2)

    train_loader = DataLoader(train_dataset, batch_size=config.BATCH_SIZE, sampler=sampler,
                              **loader_kwargs)
    val_loader = DataLoader(val_dataset, batch_size=config.BATCH_SIZE * 2, shuffle=False,
                            **loader_kwargs)

    # Model
    print("\nInitializing model...")
    model = BoneAgeGenderModel().to(config.DEVICE)

    optimizer = optim.AdamW(model.parameters(), lr=config.BASE_LR, weight_decay=config.WEIGHT_DECAY)

    warmup_epochs = 5
    # Guard against division by zero when EPOCHS equals the warmup length.
    decay_epochs = max(1, config.EPOCHS - warmup_epochs)

    def warmup_cosine(epoch):
        """LR multiplier: linear warmup for 5 epochs, then cosine decay."""
        if epoch < warmup_epochs:
            return (epoch + 1) / warmup_epochs
        return 0.5 * (1 + np.cos(np.pi * (epoch - warmup_epochs) / decay_epochs))

    scheduler = optim.lr_scheduler.LambdaLR(optimizer, warmup_cosine)

    age_criterion = nn.SmoothL1Loss()
    gender_criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
    scaler = torch.cuda.amp.GradScaler(enabled=config.USE_AMP)

    # Initialize callbacks. The early-stopping instance is kept in a variable
    # so the loop can read its flag without searching the list every epoch.
    early_stop = EarlyStopping(monitor='gender_acc', patience=config.PATIENCE, mode='max',
                               min_delta=config.MIN_DELTA)
    callbacks = [
        ModelCheckpoint(config.CHECKPOINT_DIR, monitor='gender_acc', mode='max'),
        ModelCheckpoint(config.CHECKPOINT_DIR, monitor='age_mae', mode='min'),
        early_stop,
        CSVLogger(f'{config.CHECKPOINT_DIR}/training_log.csv'),
        PeriodicCheckpoint(config.CHECKPOINT_DIR, period=5),
        LRSchedulerCallback(scheduler),
        MetricsDisplay()
    ]

    # Training
    print("\n" + banner)
    print("STARTING TRAINING")
    print(banner)

    history = {'train_loss': [], 'val_loss': [], 'age_mae': [], 'gender_acc': [], 'gender_auc': []}

    # Call on_train_start
    for callback in callbacks:
        callback.on_train_start(model=model)

    for epoch in range(1, config.EPOCHS + 1):
        print(f"\n{banner}")
        print(f"EPOCH {epoch}/{config.EPOCHS} | LR: {optimizer.param_groups[0]['lr']:.6f}")
        print(f"{banner}")

        # Call on_epoch_start
        for callback in callbacks:
            callback.on_epoch_start(epoch, model=model)

        # Train & Validate
        train_loss, _, _ = train_epoch(model, train_loader, optimizer, age_criterion,
                                       gender_criterion, scaler, epoch)
        val_loss, age_mae, gender_acc, gender_auc = validate_epoch(model, val_loader,
                                                                    age_criterion, gender_criterion, epoch)

        # Prepare logs and store history
        logs = {
            'train_loss': train_loss,
            'val_loss': val_loss,
            'age_mae': age_mae,
            'gender_acc': gender_acc,
            'gender_auc': gender_auc
        }
        for key, value in logs.items():
            history[key].append(value)

        # Call on_epoch_end
        for callback in callbacks:
            callback.on_epoch_end(epoch, logs=logs, model=model, optimizer=optimizer, scheduler=scheduler)

        # Check early stopping
        if early_stop.should_stop:
            break

        # Memory cleanup
        torch.cuda.empty_cache()
        gc.collect()

    # Call on_train_end
    for callback in callbacks:
        callback.on_train_end(model=model)

    # Final results
    print("\n" + banner)
    print("TRAINING COMPLETED")
    print(banner)
    print(f"Best Gender Acc: {max(history['gender_acc']):.4f} ({max(history['gender_acc'])*100:.2f}%)")
    print(f"Best Age MAE: {min(history['age_mae']):.4f}")
    print(f"Best Gender AUC: {max(history['gender_auc']):.4f}")

    # Plot and evaluate
    plot_history(history)

    # map_location keeps the reload working if the weights were saved on a
    # different device than the one currently in use.
    best_state = torch.load(f'{config.CHECKPOINT_DIR}/best_gender_acc.pth',
                            map_location=config.DEVICE)
    model.load_state_dict(best_state)
    _, final_age_mae, final_gender_acc, final_gender_auc = validate_epoch(
        model, val_loader, age_criterion, gender_criterion, 'Final')

    print(f"\n🎯 Final Metrics (Best Model):")
    print(f"  Age MAE: {final_age_mae:.4f}")
    print(f"  Gender Acc: {final_gender_acc:.4f} ({final_gender_acc*100:.2f}%)")
    print(f"  Gender AUC: {final_gender_auc:.4f}")

    detailed_eval(model, val_loader, val_df)

    return model, history

def plot_history(history):
    """Plot the training curves on a 2x2 grid, save the figure and show it.

    Panels: train/validation loss, age MAE, gender accuracy and gender AUC.
    The three metric panels also get a dashed red line at their best value
    (minimum for MAE, maximum for accuracy and AUC). The figure is written to
    ``<CHECKPOINT_DIR>/training_history.png``.

    Args:
        history: Dict with the keys ``train_loss``, ``val_loss``, ``age_mae``,
            ``gender_acc`` and ``gender_auc``, each a sequence of per-epoch
            values.
    """
    _, axes = plt.subplots(2, 2, figsize=(14, 10))

    # Top-left: both loss curves share one panel.
    loss_ax = axes[0, 0]
    for key, label in (('train_loss', 'Train'), ('val_loss', 'Val')):
        loss_ax.plot(history[key], label=label)
    loss_ax.set_title('Loss')
    loss_ax.legend()
    loss_ax.grid(True)

    # Remaining panels: (axis, history key, title, function picking the best value).
    metric_panels = (
        (axes[0, 1], 'age_mae', 'Age MAE', min),
        (axes[1, 0], 'gender_acc', 'Gender Accuracy', max),
        (axes[1, 1], 'gender_auc', 'Gender AUC', max),
    )
    for ax, key, title, best_of in metric_panels:
        ax.plot(history[key])
        ax.set_title(title)
        ax.grid(True)
        ax.axhline(best_of(history[key]), color='r', linestyle='--', alpha=0.5)

    plt.tight_layout()
    plt.savefig(f'{config.CHECKPOINT_DIR}/training_history.png', dpi=150)
    plt.show()

def detailed_eval(model, loader, df):
    """Print a classification report and confusion matrix for gender.

    Args:
        model: Network returning ``(age_pred, gender_pred)``.
        loader: Iterable of ``(images, age_target, gender_target)`` batches.
        df: Unused; kept so existing callers keep working.
    """
    model.eval()
    all_preds, all_targets = [], []

    with torch.no_grad():
        for images, _, gender_target in tqdm(loader, desc='Detailed Eval'):
            images = images.to(config.DEVICE)
            _, gender_pred = model(images)
            preds = gender_pred.argmax(dim=1).cpu().numpy()
            all_preds.extend(preds)
            all_targets.extend(gender_target.numpy())

    # Pin the label set so the report and the matrix are always 2-class, even
    # if one class is missing from both targets and predictions; otherwise the
    # cm[i, j] lookups below would fail on a smaller matrix.
    class_ids = [0, 1]

    print("\nClassification Report:")
    print(classification_report(all_targets, all_preds, labels=class_ids,
                                target_names=['Female', 'Male']))

    cm = confusion_matrix(all_targets, all_preds, labels=class_ids)
    print("\nConfusion Matrix:")
    print(f"              Predicted")
    print(f"              Female  Male")
    print(f"Actual Female  {cm[0,0]:5d}  {cm[0,1]:4d}")
    print(f"       Male    {cm[1,0]:5d}  {cm[1,1]:4d}")

# Start training only when this file is executed directly, not when imported.
if __name__ == "__main__":
    model, history = main()

GPU: Tesla P100-PCIE-16GB
Memory: 17.1GB

PRODUCTION TRAINING WITH CALLBACKS - 45 EPOCHS
Model: efficientnet_b3
Batch Size: 32
Epochs: 45

Loading data...
Total: 12611, Age: 1-228 months
Gender: Female=5778, Male=6833
Train: 10719, Val: 1892

Initializing model...


Downloading: "https://download.pytorch.org/models/efficientnet_b3_rwightman-b3899882.pth" to /root/.cache/torch/hub/checkpoints/efficientnet_b3_rwightman-b3899882.pth
100%|██████████| 47.2M/47.2M [00:00<00:00, 209MB/s]



STARTING TRAINING

EPOCH 1/45 | LR: 0.000100


Train Epoch 1:   0%|          | 0/335 [00:00<?, ?it/s]

Val Epoch 1:   0%|          | 0/30 [00:00<?, ?it/s]

  ✓ Best gender_acc: 0.7627 (saved)
  ✓ Best age_mae: 0.1410 (saved)

📊 Results:
  Train Loss: 0.3477 | Val Loss: 0.2762
  Age MAE: 0.1410
  Gender Acc: 0.7627 (76.27%)
  Gender AUC: 0.8554

EPOCH 2/45 | LR: 0.000200


Train Epoch 2:   0%|          | 0/335 [00:00<?, ?it/s]