In [1]:
import os
import gc
import cv2
import numpy as np
import pandas as pd
from tqdm.auto import tqdm

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler

import albumentations as A
from albumentations.pytorch import ToTensorV2

from sklearn.metrics import accuracy_score, cohen_kappa_score
from sklearn.utils.class_weight import compute_class_weight

import timm

# =============================================================================
# CONFIGURATION FOR 4-CHANNEL (IMAGE + VESSEL MASK) INPUT
# =============================================================================
class CFG:
    # --- MODEL & IMAGE SIZE (Same as your baseline EffNet-B3) ---
    MODEL_NAME = 'efficientnet_b3'
    IMG_SIZE = 384
    BATCH_SIZE = 8

    # --- DATA PATHS ---
    BASE_PATH = "/kaggle/input/aptos2019"
    TRAIN_CSV = os.path.join(BASE_PATH, "train_1.csv")
    VAL_CSV   = os.path.join(BASE_PATH, "valid.csv")
    TRAIN_DIR = os.path.join(BASE_PATH, "train_images", "train_images")
    VAL_DIR   = os.path.join(BASE_PATH, "val_images", "val_images")
    
    # --- PATHS TO YOUR NEW SEGMENTED MASKS ---
    SEG_BASE_PATH = "/kaggle/input/segmentaion-dataset/"
    SEG_TRAIN_DIR = os.path.join(SEG_BASE_PATH, "segmented_outputs_train_1/segmented_outputs_train_1/")
    SEG_VAL_DIR   = os.path.join(SEG_BASE_PATH, "segmented_outputs_val/segmented_outputs_val/")

    # --- TRAINING PIPELINE (Identical to your successful run for fair comparison) ---
    S1_EPOCHS = 15; S1_LR = 1e-4; S1_USE_MIXUP = True
    S2_EPOCHS = 15; S2_LR = 3e-5; S2_USE_MIXUP = False
    
    # --- GENERAL & SAVING ---
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    NUM_WORKERS = 2
    PATIENCE = 5
    SEED = 42
    LABEL_SMOOTHING = 0.05
    # New save paths for this experiment
    SAVE_PATH_S1 = "best_model_effnet_b3_seg_stage1.pth"
    SAVE_PATH_FINAL = "best_model_effnet_b3_seg_final.pth"

def seed_everything(seed):
    os.environ['PYTHONHASHSEED'] = str(seed); np.random.seed(seed); torch.manual_seed(seed); torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True; torch.backends.cudnn.benchmark = True
seed_everything(CFG.SEED)

# =============================================================================
# PREPROCESSING & AUGMENTATIONS (CORRECTED LOGIC)
# =============================================================================
def preprocess_ben_graham(image, output_size):
    # This function only preprocesses the 3-channel image
    try:
        gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
        if gray.mean() < 15: 
            image = cv2.resize(image, (output_size, output_size), interpolation=cv2.INTER_AREA)
        else:
            _, thresh = cv2.threshold(gray, 15, 255, cv2.THRESH_BINARY)
            contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
            if contours:
                largest_contour = max(contours, key=cv2.contourArea)
                x, y, w, h = cv2.boundingRect(largest_contour)
                image = image[y:y+h, x:x+w]
            image = cv2.resize(image, (output_size, output_size), interpolation=cv2.INTER_AREA)
    except Exception: 
        image = cv2.resize(image, (output_size, output_size), interpolation=cv2.INTER_AREA)
    
    b, g, r = cv2.split(image)
    clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    g = clahe.apply(g)
    
    return cv2.merge((b, g, r))

def get_transforms(is_train=True):
    # This pipeline now only contains augmentations. Preprocessing happens before.
    if is_train:
        return A.Compose([
            A.HorizontalFlip(p=0.5),
            A.VerticalFlip(p=0.5),
            A.ShiftScaleRotate(shift_limit=0.05, scale_limit=0.1, rotate_limit=15, p=0.7),
            A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.7),
            A.HueSaturationValue(hue_shift_limit=10, sat_shift_limit=20, val_shift_limit=10, p=0.5),
        ])
    else:
        # No augmentations for validation/test
        return None

# =============================================================================
# UPGRADED DATASET (CORRECTED LOGIC)
# =============================================================================
class Dataset4Channel(Dataset):
    def __init__(self, df, img_dir, seg_dir, transform=None):
        self.df = df.reset_index(drop=True)
        self.img_dir = img_dir
        self.seg_dir = seg_dir
        self.transform = transform
        # The final normalization/tensor conversion is always applied
        self.post_transform = A.Compose([
            A.Normalize(mean=[0.485, 0.456, 0.406, 0.5], std=[0.229, 0.224, 0.225, 0.5]),
            ToTensorV2()
        ])

    def __len__(self): 
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        img_path = os.path.join(self.img_dir, row['id_code'] + '.png')
        seg_path = os.path.join(self.seg_dir, row['id_code'] + '.png')
        
        img = cv2.imread(img_path)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        
        mask = cv2.imread(seg_path, cv2.IMREAD_GRAYSCALE)
        
        # Step 1: Apply preprocessing to the 3-channel RGB image first
        img = preprocess_ben_graham(img, CFG.IMG_SIZE)
        
        # Step 2: Resize the mask to the exact same size to ensure alignment
        mask = cv2.resize(mask, (CFG.IMG_SIZE, CFG.IMG_SIZE), interpolation=cv2.INTER_NEAREST)
        
        # Step 3: Apply geometric and color augmentations to the ALIGNED pair
        if self.transform:
            augmented = self.transform(image=img, mask=mask)
            img = augmented['image']
            mask = augmented['mask']
            
        # Step 4: Add the mask as the 4th channel
        img_4_channel = np.dstack((img, mask))
        
        # Step 5: Apply final normalization and convert to tensor
        img_4_channel = self.post_transform(image=img_4_channel)['image']
            
        label = torch.tensor(row['diagnosis'], dtype=torch.long)
        return img_4_channel, label

# =============================================================================
# UPGRADED MODEL TO ACCEPT 4 CHANNELS (Unchanged, was already correct)
# =============================================================================
class DualStreamEfficientNetOrdinal(nn.Module):
    def __init__(self, model_name="efficientnet_b3", num_classes=5, pretrained=True):
        super().__init__()
        # --- RGB branch (keep pretrained EfficientNet) ---
        self.rgb_backbone = timm.create_model(model_name, pretrained=pretrained, num_classes=0, global_pool='avg')
        rgb_dim = self.rgb_backbone.num_features

        # --- Mask branch (tiny CNN encoder) ---
        self.mask_encoder = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=3, stride=2, padding=1), nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1), nn.ReLU(),
            nn.AdaptiveAvgPool2d((1, 1))
        )
        mask_dim = 32

        # --- Fusion head ---
        fusion_dim = rgb_dim + mask_dim
        self.classifier = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(fusion_dim, 256), nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, num_classes - 1)
        )

    def forward(self, x):
        # Split input: RGB (first 3 channels), Mask (last channel)
        rgb = x[:, :3, :, :]
        mask = x[:, 3:, :, :]

        # Forward pass
        rgb_feat = self.rgb_backbone(rgb)                  # [B, rgb_dim]
        mask_feat = self.mask_encoder(mask).flatten(1)     # [B, mask_dim]

        fused = torch.cat([rgb_feat, mask_feat], dim=1)    # concat features
        return self.classifier(fused)

# --- Loss functions, training loops, and other utilities are unchanged ---
class WeightedOrdinalFocalLoss(nn.Module):
    def __init__(self, num_classes=5, gamma=2.0, class_weights=None, label_smoothing=0.0):
        super().__init__()
        self.num_classes, self.gamma, self.class_weights, self.label_smoothing = num_classes, gamma, class_weights, label_smoothing
        self.bce = nn.BCEWithLogitsLoss(reduction='none')
    def forward(self, outputs, targets):
        ordinal_targets = torch.zeros_like(outputs)
        for i, t in enumerate(targets):
            if t > 0: ordinal_targets[i, :t] = 1.0
        if self.label_smoothing > 0.0: ordinal_targets = ordinal_targets * (1.0 - self.label_smoothing) + 0.5 * self.label_smoothing
        bce = self.bce(outputs, ordinal_targets)
        if self.class_weights is not None:
            weights = self.class_weights[targets].view(-1, 1).expand(-1, outputs.shape[1])
            bce = bce * weights
        pt = torch.exp(-bce)
        focal = (1 - pt) ** self.gamma * bce
        return focal.mean()

class SmoothKappaLoss(nn.Module):
    def __init__(self, num_classes=5, eps=1e-7):
        super().__init__()
        self.num_classes, self.eps = num_classes, eps
        W = torch.zeros(num_classes, num_classes)
        for i in range(num_classes):
            for j in range(num_classes): W[i,j] = ((i - j)**2) / ((num_classes - 1)**2)
        self.register_buffer("W", W)
    def forward(self, outputs, targets):
        device = outputs.device; B = outputs.size(0); probs = torch.sigmoid(outputs)
        class_probs = torch.zeros(B, self.num_classes, device=device)
        class_probs[:, 0] = 1 - probs[:, 0]
        for k in range(1, self.num_classes-1): class_probs[:, k] = probs[:, k-1] - probs[:, k]
        class_probs[:, -1] = probs[:, -1]
        class_probs = torch.clamp(class_probs, min=self.eps, max=1.0)
        one_hot = F.one_hot(targets, num_classes=self.num_classes).float().to(device)
        conf_mat = torch.matmul(one_hot.T, class_probs)
        hist_true = one_hot.sum(dim=0); hist_pred = class_probs.sum(dim=0)
        expected = torch.outer(hist_true, hist_pred)
        W = self.W.to(device); obs = torch.sum(W * conf_mat); exp = torch.sum(W * expected)
        kappa = 1.0 - (B * obs) / (exp + self.eps)
        return 1.0 - kappa

def mixup_data(x, y, alpha=0.4):
    if alpha > 0: lam = np.random.beta(alpha, alpha)
    else: lam = 1
    batch_size = x.size()[0]; index = torch.randperm(batch_size).to(x.device)
    mixed_x = lam * x + (1 - lam) * x[index, :]
    y_a, y_b = y, y[index]
    return mixed_x, y_a, y_b, lam

def ordinal_to_class(outputs): 
    return torch.sum(torch.sigmoid(outputs) > 0.5, dim=1).long()

def calculate_metrics(outputs, targets):
    preds = ordinal_to_class(outputs).cpu().numpy()
    targets_np = targets.cpu().numpy()
    return accuracy_score(targets_np, preds), cohen_kappa_score(targets_np, preds, weights='quadratic')

def clear_memory(): 
    gc.collect()
    torch.cuda.empty_cache()

def train_epoch(model, loader, optimizer, criterion, scaler, device, use_mixup):
    model.train(); running_loss = 0.0; all_out, all_t = [], []
    pbar = tqdm(loader, desc="Training", leave=False)
    for images, targets in pbar:
        images, targets = images.to(device, non_blocking=True), targets.to(device, non_blocking=True)
        optimizer.zero_grad(set_to_none=True)
        if use_mixup: images, targets_a, targets_b, lam = mixup_data(images, targets)
        with torch.cuda.amp.autocast():
            outputs = model(images)
            if use_mixup: loss = lam * criterion(outputs, targets_a) + (1 - lam) * criterion(outputs, targets_b)
            else: loss = criterion(outputs, targets)
        scaler.scale(loss).backward(); scaler.step(optimizer); scaler.update()
        running_loss += loss.item(); all_out.append(outputs.detach()); all_t.append(targets.detach())
        pbar.set_postfix(loss=loss.item())
    all_out, all_t = torch.cat(all_out), torch.cat(all_t)
    return running_loss / len(loader), *calculate_metrics(all_out, all_t)

def validate_epoch(model, loader, criterion, device):
    model.eval(); running_loss = 0.0; all_out, all_t = [], []
    with torch.no_grad():
        pbar = tqdm(loader, desc="Validating", leave=False)
        for images, targets in pbar:
            images, targets = images.to(device, non_blocking=True), targets.to(device, non_blocking=True)
            with torch.cuda.amp.autocast():
                outputs = model(images)
                loss = criterion(outputs, targets)
            running_loss += loss.item()
            all_out.append(outputs)
            all_t.append(targets)
    all_out, all_t = torch.cat(all_out), torch.cat(all_t)
    return running_loss / len(loader), *calculate_metrics(all_out, all_t)

def main():
    print(f"Device: {CFG.DEVICE}, Model: {CFG.MODEL_NAME} (4-Channel), Image Size: {CFG.IMG_SIZE}")
    train_df = pd.read_csv(CFG.TRAIN_CSV)
    val_df = pd.read_csv(CFG.VAL_CSV)
    
    train_tf = get_transforms(is_train=True)
    val_tf = get_transforms(is_train=False)

    train_ds = Dataset4Channel(train_df, CFG.TRAIN_DIR, CFG.SEG_TRAIN_DIR, transform=train_tf)
    val_ds   = Dataset4Channel(val_df, CFG.VAL_DIR, CFG.SEG_VAL_DIR, transform=val_tf)

    class_weights_sampler = compute_class_weight('balanced', classes=np.unique(train_df['diagnosis']), y=train_df['diagnosis'])
    sample_weights = np.array([class_weights_sampler[int(l)] for l in train_df['diagnosis']])
    sampler = WeightedRandomSampler(sample_weights, num_samples=len(sample_weights), replacement=True)
    train_loader = DataLoader(train_ds, batch_size=CFG.BATCH_SIZE, sampler=sampler, num_workers=CFG.NUM_WORKERS, pin_memory=True)
    val_loader = DataLoader(val_ds, batch_size=CFG.BATCH_SIZE*2, shuffle=False, num_workers=CFG.NUM_WORKERS, pin_memory=True)
    
    model = DualStreamEfficientNetOrdinal(CFG.MODEL_NAME).to(CFG.DEVICE)
    class_weights_loss = torch.tensor(class_weights_sampler, dtype=torch.float).to(CFG.DEVICE)
    focal_loss = WeightedOrdinalFocalLoss(num_classes=5, gamma=2.0, class_weights=class_weights_loss, label_smoothing=CFG.LABEL_SMOOTHING)
    kappa_loss = SmoothKappaLoss(num_classes=5)
    
    def hybrid_loss(outputs, targets): 
        return 0.7 * kappa_loss(outputs, targets) + 0.3 * focal_loss(outputs, targets)
    
    scaler = torch.cuda.amp.GradScaler()

    # --- STAGE 1 ---
    print("\n" + "="*50 + "\n     STARTING STAGE 1 (4-Channel)\n" + "="*50)
    opt = optim.AdamW(model.parameters(), lr=CFG.S1_LR, weight_decay=1e-4)
    sched = optim.lr_scheduler.CosineAnnealingLR(opt, T_max=CFG.S1_EPOCHS)
    best_val_qwk, patience_counter = -1, 0

    for epoch in range(CFG.S1_EPOCHS):
        clear_memory()
        print(f"\nEpoch {epoch+1}/{CFG.S1_EPOCHS}")
        train_loss, train_acc, train_qwk = train_epoch(model, train_loader, opt, focal_loss, scaler, CFG.DEVICE, CFG.S1_USE_MIXUP)
        val_loss, val_acc, val_qwk = validate_epoch(model, val_loader, focal_loss, CFG.DEVICE)
        sched.step()
        print(f"Train -> Loss:{train_loss:.4f} Acc:{train_acc:.4f} QWK:{train_qwk:.4f}")
        print(f"Valid -> Loss:{val_loss:.4f} Acc:{val_acc:.4f} QWK:{val_qwk:.4f}")
        if val_qwk > best_val_qwk:
            print(f"Val QWK improved from {best_val_qwk:.4f} to {val_qwk:.4f}. Saving model...")
            best_val_qwk, patience_counter = val_qwk, 0
            torch.save(model.state_dict(), CFG.SAVE_PATH_S1)
        else:
            patience_counter += 1
            if patience_counter >= CFG.PATIENCE: 
                print("Early stopping in Stage 1.")
                break
    
    # --- STAGE 2 ---
    print("\n" + "="*50 + "\n     STARTING STAGE 2 (4-Channel)\n" + "="*50)
    if os.path.exists(CFG.SAVE_PATH_S1):
        model.load_state_dict(torch.load(CFG.SAVE_PATH_S1))
    else:
        print("No Stage 1 model was saved. Continuing with the current model.")

    opt = optim.AdamW(model.parameters(), lr=CFG.S2_LR, weight_decay=1e-5)
    sched = optim.lr_scheduler.CosineAnnealingLR(opt, T_max=CFG.S2_EPOCHS)
    best_val_qwk_stage2, patience_counter = best_val_qwk, 0

    for epoch in range(CFG.S2_EPOCHS):
        clear_memory()
        print(f"\nEpoch {epoch+1}/{CFG.S2_EPOCHS}")
        train_loss, train_acc, train_qwk = train_epoch(model, train_loader, opt, hybrid_loss, scaler, CFG.DEVICE, CFG.S2_USE_MIXUP)
        val_loss, val_acc, val_qwk = validate_epoch(model, val_loader, hybrid_loss, CFG.DEVICE)
        sched.step()
        print(f"Train -> Loss:{train_loss:.4f} Acc:{train_acc:.4f} QWK:{train_qwk:.4f}")
        print(f"Valid -> Loss:{val_loss:.4f} Acc:{val_acc:.4f} QWK:{val_qwk:.4f}")
        if val_qwk > best_val_qwk_stage2:
            print(f"Val QWK improved from {best_val_qwk_stage2:.4f} to {val_qwk:.4f}. Saving final model...")
            best_val_qwk_stage2, patience_counter = val_qwk, 0
            torch.save(model.state_dict(), CFG.SAVE_PATH_FINAL)
        else:
            patience_counter += 1
            if patience_counter >= CFG.PATIENCE: 
                print("Early stopping in Stage 2.")
                break

    print(f"\nTraining Finished!\nFinal Best QWK: {best_val_qwk_stage2:.4f}")

if __name__ == "__main__":
    main()



Device: cuda, Model: efficientnet_b3 (4-Channel), Image Size: 384


  original_init(self, **validated_kwargs)


model.safetensors:   0%|          | 0.00/49.3M [00:00<?, ?B/s]


     STARTING STAGE 1 (4-Channel)

Epoch 1/15


  scaler = torch.cuda.amp.GradScaler()


Training:   0%|          | 0/367 [00:00<?, ?it/s]

  with torch.cuda.amp.autocast():


Validating:   0%|          | 0/23 [00:00<?, ?it/s]

  with torch.cuda.amp.autocast():


Train -> Loss:0.4256 Acc:0.2416 QWK:0.1768
Valid -> Loss:0.1539 Acc:0.2049 QWK:0.6408
Val QWK improved from -1.0000 to 0.6408. Saving model...

Epoch 2/15


Training:   0%|          | 0/367 [00:00<?, ?it/s]

  with torch.cuda.amp.autocast():


Validating:   0%|          | 0/23 [00:00<?, ?it/s]

  with torch.cuda.amp.autocast():


Train -> Loss:0.3467 Acc:0.3150 QWK:0.3239
Valid -> Loss:0.1462 Acc:0.1940 QWK:0.5640

Epoch 3/15


Training:   0%|          | 0/367 [00:00<?, ?it/s]

  with torch.cuda.amp.autocast():


Validating:   0%|          | 0/23 [00:00<?, ?it/s]

  with torch.cuda.amp.autocast():


Train -> Loss:0.3244 Acc:0.3567 QWK:0.3513
Valid -> Loss:0.1392 Acc:0.2022 QWK:0.6446
Val QWK improved from 0.6408 to 0.6446. Saving model...

Epoch 4/15


Training:   0%|          | 0/367 [00:00<?, ?it/s]

  with torch.cuda.amp.autocast():


Validating:   0%|          | 0/23 [00:00<?, ?it/s]

  with torch.cuda.amp.autocast():


Train -> Loss:0.2743 Acc:0.3898 QWK:0.3980
Valid -> Loss:0.1471 Acc:0.2514 QWK:0.6881
Val QWK improved from 0.6446 to 0.6881. Saving model...

Epoch 5/15


Training:   0%|          | 0/367 [00:00<?, ?it/s]

  with torch.cuda.amp.autocast():


Validating:   0%|          | 0/23 [00:00<?, ?it/s]

  with torch.cuda.amp.autocast():


Train -> Loss:0.2600 Acc:0.4017 QWK:0.4168
Valid -> Loss:0.1403 Acc:0.2158 QWK:0.6644

Epoch 6/15


Training:   0%|          | 0/367 [00:00<?, ?it/s]

  with torch.cuda.amp.autocast():


Validating:   0%|          | 0/23 [00:00<?, ?it/s]

  with torch.cuda.amp.autocast():


Train -> Loss:0.2404 Acc:0.3959 QWK:0.4114
Valid -> Loss:0.1420 Acc:0.2186 QWK:0.6657

Epoch 7/15


Training:   0%|          | 0/367 [00:00<?, ?it/s]

  with torch.cuda.amp.autocast():


Validating:   0%|          | 0/23 [00:00<?, ?it/s]

  with torch.cuda.amp.autocast():


Train -> Loss:0.2313 Acc:0.4218 QWK:0.4524
Valid -> Loss:0.1412 Acc:0.2486 QWK:0.6343

Epoch 8/15


Training:   0%|          | 0/367 [00:00<?, ?it/s]

  with torch.cuda.amp.autocast():


Validating:   0%|          | 0/23 [00:00<?, ?it/s]

  with torch.cuda.amp.autocast():


Train -> Loss:0.2114 Acc:0.4263 QWK:0.4556
Valid -> Loss:0.1276 Acc:0.2186 QWK:0.6774

Epoch 9/15


Training:   0%|          | 0/367 [00:00<?, ?it/s]

  with torch.cuda.amp.autocast():


Validating:   0%|          | 0/23 [00:00<?, ?it/s]

  with torch.cuda.amp.autocast():


Train -> Loss:0.2150 Acc:0.4126 QWK:0.4353
Valid -> Loss:0.1335 Acc:0.2541 QWK:0.6955
Val QWK improved from 0.6881 to 0.6955. Saving model...

Epoch 10/15


Training:   0%|          | 0/367 [00:00<?, ?it/s]

  with torch.cuda.amp.autocast():


Validating:   0%|          | 0/23 [00:00<?, ?it/s]

  with torch.cuda.amp.autocast():


Train -> Loss:0.1925 Acc:0.4549 QWK:0.5055
Valid -> Loss:0.1329 Acc:0.2678 QWK:0.6656

Epoch 11/15


Training:   0%|          | 0/367 [00:00<?, ?it/s]

  with torch.cuda.amp.autocast():


Validating:   0%|          | 0/23 [00:00<?, ?it/s]

  with torch.cuda.amp.autocast():


Train -> Loss:0.2007 Acc:0.4338 QWK:0.4488
Valid -> Loss:0.1316 Acc:0.2678 QWK:0.6556

Epoch 12/15


Training:   0%|          | 0/367 [00:00<?, ?it/s]

  with torch.cuda.amp.autocast():


Validating:   0%|          | 0/23 [00:00<?, ?it/s]

  with torch.cuda.amp.autocast():


Train -> Loss:0.2076 Acc:0.4321 QWK:0.4612
Valid -> Loss:0.1331 Acc:0.2896 QWK:0.6872

Epoch 13/15


Training:   0%|          | 0/367 [00:00<?, ?it/s]

  with torch.cuda.amp.autocast():


Validating:   0%|          | 0/23 [00:00<?, ?it/s]

  with torch.cuda.amp.autocast():


Train -> Loss:0.1992 Acc:0.4502 QWK:0.4700
Valid -> Loss:0.1314 Acc:0.2923 QWK:0.7030
Val QWK improved from 0.6955 to 0.7030. Saving model...

Epoch 14/15


Training:   0%|          | 0/367 [00:00<?, ?it/s]

  with torch.cuda.amp.autocast():


Validating:   0%|          | 0/23 [00:00<?, ?it/s]

  with torch.cuda.amp.autocast():


Train -> Loss:0.1863 Acc:0.4225 QWK:0.4509
Valid -> Loss:0.1286 Acc:0.2869 QWK:0.6901

Epoch 15/15


Training:   0%|          | 0/367 [00:00<?, ?it/s]

  with torch.cuda.amp.autocast():


Validating:   0%|          | 0/23 [00:00<?, ?it/s]

  with torch.cuda.amp.autocast():


Train -> Loss:0.2018 Acc:0.4382 QWK:0.4524
Valid -> Loss:0.1281 Acc:0.3005 QWK:0.6760

     STARTING STAGE 2 (4-Channel)

Epoch 1/15


Training:   0%|          | 0/367 [00:00<?, ?it/s]

  with torch.cuda.amp.autocast():


Validating:   0%|          | 0/23 [00:00<?, ?it/s]

  with torch.cuda.amp.autocast():


Train -> Loss:0.1650 Acc:0.7140 QWK:0.9097
Valid -> Loss:0.2017 Acc:0.7049 QWK:0.8843
Val QWK improved from 0.7030 to 0.8843. Saving final model...

Epoch 2/15


Training:   0%|          | 0/367 [00:00<?, ?it/s]

  with torch.cuda.amp.autocast():


Validating:   0%|          | 0/23 [00:00<?, ?it/s]

  with torch.cuda.amp.autocast():


Train -> Loss:0.1088 Acc:0.8232 QWK:0.9488
Valid -> Loss:0.1678 Acc:0.7678 QWK:0.8978
Val QWK improved from 0.8843 to 0.8978. Saving final model...

Epoch 3/15


Training:   0%|          | 0/367 [00:00<?, ?it/s]

  with torch.cuda.amp.autocast():


Validating:   0%|          | 0/23 [00:00<?, ?it/s]

  with torch.cuda.amp.autocast():


Train -> Loss:0.0894 Acc:0.8652 QWK:0.9605
Valid -> Loss:0.1711 Acc:0.7896 QWK:0.9042
Val QWK improved from 0.8978 to 0.9042. Saving final model...

Epoch 4/15


Training:   0%|          | 0/367 [00:00<?, ?it/s]

  with torch.cuda.amp.autocast():


Validating:   0%|          | 0/23 [00:00<?, ?it/s]

  with torch.cuda.amp.autocast():


Train -> Loss:0.0837 Acc:0.8857 QWK:0.9655
Valid -> Loss:0.1663 Acc:0.7732 QWK:0.8959

Epoch 5/15


Training:   0%|          | 0/367 [00:00<?, ?it/s]

  with torch.cuda.amp.autocast():


Validating:   0%|          | 0/23 [00:00<?, ?it/s]

  with torch.cuda.amp.autocast():


Train -> Loss:0.0825 Acc:0.8843 QWK:0.9640
Valid -> Loss:0.1646 Acc:0.7814 QWK:0.9101
Val QWK improved from 0.9042 to 0.9101. Saving final model...

Epoch 6/15


Training:   0%|          | 0/367 [00:00<?, ?it/s]

  with torch.cuda.amp.autocast():


Validating:   0%|          | 0/23 [00:00<?, ?it/s]

  with torch.cuda.amp.autocast():


Train -> Loss:0.0763 Acc:0.9003 QWK:0.9716
Valid -> Loss:0.1651 Acc:0.7896 QWK:0.9048

Epoch 7/15


Training:   0%|          | 0/367 [00:00<?, ?it/s]

  with torch.cuda.amp.autocast():


Validating:   0%|          | 0/23 [00:00<?, ?it/s]

  with torch.cuda.amp.autocast():


Train -> Loss:0.0715 Acc:0.9034 QWK:0.9708
Valid -> Loss:0.1641 Acc:0.7951 QWK:0.9068

Epoch 8/15


Training:   0%|          | 0/367 [00:00<?, ?it/s]

  with torch.cuda.amp.autocast():


Validating:   0%|          | 0/23 [00:00<?, ?it/s]

  with torch.cuda.amp.autocast():


Train -> Loss:0.0729 Acc:0.9017 QWK:0.9706
Valid -> Loss:0.1589 Acc:0.8033 QWK:0.9121
Val QWK improved from 0.9101 to 0.9121. Saving final model...

Epoch 9/15


Training:   0%|          | 0/367 [00:00<?, ?it/s]

  with torch.cuda.amp.autocast():


Validating:   0%|          | 0/23 [00:00<?, ?it/s]

  with torch.cuda.amp.autocast():


Train -> Loss:0.0697 Acc:0.9191 QWK:0.9743
Valid -> Loss:0.1699 Acc:0.8005 QWK:0.9101

Epoch 10/15


Training:   0%|          | 0/367 [00:00<?, ?it/s]

  with torch.cuda.amp.autocast():


Validating:   0%|          | 0/23 [00:00<?, ?it/s]

  with torch.cuda.amp.autocast():


Train -> Loss:0.0617 Acc:0.9222 QWK:0.9764
Valid -> Loss:0.1628 Acc:0.8115 QWK:0.9125
Val QWK improved from 0.9121 to 0.9125. Saving final model...

Epoch 11/15


Training:   0%|          | 0/367 [00:00<?, ?it/s]

  with torch.cuda.amp.autocast():


Validating:   0%|          | 0/23 [00:00<?, ?it/s]

  with torch.cuda.amp.autocast():


Train -> Loss:0.0583 Acc:0.9253 QWK:0.9794
Valid -> Loss:0.1630 Acc:0.8005 QWK:0.9094

Epoch 12/15


Training:   0%|          | 0/367 [00:00<?, ?it/s]

  with torch.cuda.amp.autocast():


Validating:   0%|          | 0/23 [00:00<?, ?it/s]

  with torch.cuda.amp.autocast():


Train -> Loss:0.0602 Acc:0.9208 QWK:0.9787
Valid -> Loss:0.1680 Acc:0.8005 QWK:0.9047

Epoch 13/15


Training:   0%|          | 0/367 [00:00<?, ?it/s]

  with torch.cuda.amp.autocast():


Validating:   0%|          | 0/23 [00:00<?, ?it/s]

  with torch.cuda.amp.autocast():


Train -> Loss:0.0623 Acc:0.9242 QWK:0.9775
Valid -> Loss:0.1737 Acc:0.8033 QWK:0.9082

Epoch 14/15


Training:   0%|          | 0/367 [00:00<?, ?it/s]

  with torch.cuda.amp.autocast():


Validating:   0%|          | 0/23 [00:00<?, ?it/s]

  with torch.cuda.amp.autocast():


Train -> Loss:0.0613 Acc:0.9195 QWK:0.9780
Valid -> Loss:0.1657 Acc:0.8060 QWK:0.9118

Epoch 15/15


Training:   0%|          | 0/367 [00:00<?, ?it/s]

  with torch.cuda.amp.autocast():


Validating:   0%|          | 0/23 [00:00<?, ?it/s]

  with torch.cuda.amp.autocast():


Train -> Loss:0.0628 Acc:0.9253 QWK:0.9774
Valid -> Loss:0.1626 Acc:0.8115 QWK:0.9124
Early stopping in Stage 2.

Training Finished!
Final Best QWK: 0.9125
