In [2]:
import os
import time
import random
import pandas as pd
from PIL import Image
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, Subset
from torchvision import transforms
import timm
from transformers import get_linear_schedule_with_warmup
from sklearn.metrics import (
    accuracy_score,
    precision_recall_fscore_support,
    roc_auc_score,
)

In [5]:
# ───────────────
# 1. Dataset
# ───────────────
class CustomCelebADataset(Dataset):
    def __init__(self, root, split="train", transform=None, attr_idx=None):
        self.root = root
        self.transform = transform
        self.attr_idx = attr_idx    # integer index of the one attribute you want

        # metadata paths
        attr_path = os.path.join(root, "celeba", "list_attr_celeba.txt")
        split_path = os.path.join(root, "celeba", "list_eval_partition.txt")
        img_folder = os.path.join(root, "celeba", "img_align_celeba")

        # load attributes
        with open(attr_path) as f:
            lines = f.readlines()
        header = lines[1].strip().split()
        data = [l.strip().split() for l in lines[2:]]
        df_attr = pd.DataFrame(data, columns=["filename"] + header)
        df_attr[header] = df_attr[header].astype(int)
        df_attr[header] = (df_attr[header] == 1).astype(int)

        # load train/val/test split
        df_split = pd.read_csv(split_path, delim_whitespace=True,
                               header=None, names=["filename", "split"])
        df = pd.merge(df_attr, df_split, on="filename")
        split_map = {"train": 0, "valid": 1, "test": 2}
        self.df = df[df["split"] == split_map[split]].reset_index(drop=True)

        self.img_folder = img_folder
        self.attr_names = header

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        img = Image.open(os.path.join(self.img_folder, row["filename"])).convert("RGB")
        if self.transform:
            img = self.transform(img)
        #labels = torch.tensor(row[self.attr_names].values.astype("float32"))

        all_labels = row[self.attr_names].values.astype("float32")
        single = float(all_labels[self.attr_idx])
        labels = torch.tensor([single])

        return img, labels


In [7]:
# ───────────────
# 2. Transforms
# ───────────────
# ImageNet mean/std (ViT pretrained)
MEAN = [0.485, 0.456, 0.406]
STD  = [0.229, 0.224, 0.225]

train_transform = transforms.Compose([
    # ensure divisible by patch size
    transforms.RandomResizedCrop(224, scale=(0.8,1.0)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ColorJitter(brightness=0.2, contrast=0.2),
    transforms.RandomRotation(15),
    transforms.ToTensor(),
    transforms.Normalize(MEAN, STD),
])

val_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(MEAN, STD),
])

In [9]:
# ───────────────
# 3. DataLoaders
# ───────────────
root = "/Users/sarthakmorj/Downloads/data"
batch_size = 64
num_workers = 0

train_ds = CustomCelebADataset(root, split="train", transform=train_transform, attr_idx=31)
val_ds   = CustomCelebADataset(root, split="valid", transform=val_transform, attr_idx=31)
test_ds  = CustomCelebADataset(root, split="test",  transform=val_transform, attr_idx=31)

train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True,
                          num_workers=num_workers, pin_memory=False)
val_loader   = DataLoader(val_ds,   batch_size=batch_size, shuffle=False,
                          num_workers=num_workers, pin_memory=False)
test_loader  = DataLoader(test_ds,  batch_size=batch_size, shuffle=False,
                          num_workers=num_workers, pin_memory=False)


  df_split = pd.read_csv(split_path, delim_whitespace=True,
  df_split = pd.read_csv(split_path, delim_whitespace=True,
  df_split = pd.read_csv(split_path, delim_whitespace=True,


In [13]:
from transformers import ViTForImageClassification, BitsAndBytesConfig
from peft import (
    LoraConfig,
    get_peft_model,
    TaskType,
)

model = ViTForImageClassification.from_pretrained(
    "google/vit-base-patch16-224",
    num_labels=1,              # num_labels=1 
    ignore_mismatched_sizes=True,
)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized because the shapes did not match:
- classifier.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([1]) in the model instantiated
- classifier.weight: found shape torch.Size([1000, 768]) in the checkpoint and torch.Size([1, 768]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


ViTForImageClassification(
  (vit): ViTModel(
    (embeddings): ViTEmbeddings(
      (patch_embeddings): ViTPatchEmbeddings(
        (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): ViTEncoder(
      (layer): ModuleList(
        (0-11): 12 x ViTLayer(
          (attention): ViTAttention(
            (attention): ViTSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
            )
            (output): ViTSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (intermediate): ViTIntermediate(
            (dense): Linear(in_features=768, out_features=3072, bias=True)
            (intermed

In [15]:
labels = train_ds.df[train_ds.attr_names[31]].values.astype(int)
pos = labels.sum()
neg = len(labels) - pos
pos_weight = torch.tensor(neg/pos).to(device)

In [17]:
from peft import LoraConfig, get_peft_model
lora_config = LoraConfig(
    inference_mode=False,       
    r=16,                       # LoRA rank
    lora_alpha=32,              # LoRA scaling
    target_modules=["query",    # inject into self-attention Q, K, V
                    "key",
                    "value"],
    lora_dropout=0.05,
    bias="none",
    modules_to_save=["classifier"],  # ensures ourclassification head stays trainable
)

# wrapping the HF model with PEFT → only LoRA params will be trainable
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()  # verify only LoRA adapters are unfrozen

trainable params: 885,505 || all params: 86,684,930 || trainable%: 1.0215


In [19]:
# ───────────────
# 5. Loss / Optimizer / Scheduler
# ───────────────
criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=3e-4,
    weight_decay=0.01
)

lora_r = 16
lora_alpha = 32
lr=3e-4,
weight_decay=0.01,
epochs = 10
total_steps = len(train_loader) * epochs
warmup_steps = int(0.1 * total_steps)
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=warmup_steps,
    num_training_steps=total_steps
)

In [22]:
# ─── Checkpoint setup ───────────────────────────────────────────────────────
import os, torch

checkpoint_path = f"/Users/sarthakmorj/Desktop/checkpoint_bs{batch_size}_lr{lr}_wd{weight_decay}_lora_r{lora_r}_lora_alpha{lora_alpha}.pth"
start_epoch   = 1
start_batch   = 0

if os.path.isfile(checkpoint_path):
    ckpt = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict( ckpt["model_state"] )
    optimizer.load_state_dict( ckpt["opt_state"]   )
    scheduler.load_state_dict( ckpt["sched_state"] )
    start_epoch = ckpt["epoch"]
    start_batch = ckpt["batch_idx"] + 1
    # if we had finished that batch already, move on to the next epoch
    if start_batch >= len(train_loader):
        start_epoch += 1
        start_batch = 0
    print(f"=> Resuming at epoch {start_epoch}, batch {start_batch}")
else:
    print("=> No checkpoint found, starting from scratch")


=> No checkpoint found, starting from scratch


In [24]:
# ───────────────
# 6. Training Loop
# ───────────────
def train_one_epoch(epoch, resume_batch):
    model.train()
    running_loss = 0.0
    t0 = time.time()

    pbar = tqdm(
        train_loader,
        desc=f"Epoch {epoch}/{epochs}",
        total=len(train_loader),
        initial=resume_batch,
        leave=False
    )

    for batch_idx, (imgs, labels) in enumerate(pbar):
        # skip already-done batches
        if epoch == start_epoch and batch_idx < resume_batch:
            continue

        imgs, labels = imgs.to(device), labels.to(device)
        outputs = model(imgs)
        logits  = outputs.logits
        loss    = criterion(logits, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()

        # update running loss & progress bar
        running_loss += loss.item() * imgs.size(0)
        seen = (pbar.n + 1) * imgs.size(0)
        pbar.set_postfix(train_loss=f"{running_loss/seen:.4f}")

        # checkpoint right after this batch
        torch.save({
            "epoch":       epoch,
            "batch_idx":   batch_idx,
            "model_state": model.state_dict(),
            "opt_state":   optimizer.state_dict(),
            "sched_state": scheduler.state_dict(),
            "hparams": {
                "lr": lr,
                "weight_decay": weight_decay,
                "warmup_steps": warmup_steps,
                "lora_r": lora_r,
                "lora_alpha": lora_alpha
                
            }
        }, checkpoint_path)

    pbar.close()
    avg_loss = running_loss / len(train_ds)
    return avg_loss, time.time() - t0




In [26]:
from sklearn.metrics import roc_curve
@torch.no_grad()
def evaluate_binary(loader):
    model.eval()
    all_y, all_p = [], []

    for imgs, labels in loader:
        imgs = imgs.to(device)
        outputs = model(imgs)
        logits  = outputs.logits.squeeze(-1)

        probs  = torch.sigmoid(logits).cpu().numpy()
        all_p.append(probs)
        all_y.append(labels.view(-1).numpy())

    y_true = np.concatenate(all_y)
    y_prob = np.concatenate(all_p)
    fpr, tpr, thresh = roc_curve(y_true, y_prob)
    best = thresh[(tpr - fpr).argmax()]
    y_pred = (y_prob > best).astype(int)

    acc = accuracy_score(y_true, y_pred)
    p, r, f1, _ = precision_recall_fscore_support(
        y_true, y_pred, average="binary", zero_division=0
    )
    roc_auc = roc_auc_score(y_true, y_prob)

    return {"acc": acc, "prec": p, "rec": r, "f1": f1, "roc_auc": roc_auc}

In [28]:
import numpy as np
from tqdm import trange, tqdm

for epoch in trange(start_epoch, epochs + 1, desc="Overall"):
    train_loss, train_time = train_one_epoch(epoch, start_batch)
    # after first epoch, always start batch at 0
    start_batch = 0

    metrics = evaluate_binary(val_loader)
    print(
        f"Epoch {epoch}/{epochs} • "
        f"Train loss={train_loss:.4f} ({train_time:.1f}s) • "
        f"Val Acc={metrics['acc']:.4f} • "
        f"Prec={metrics['prec']:.4f} • "
        f"Recall={metrics['rec']:.4f} • "
        f"F1={metrics['f1']:.4f} • "
        f"AUC={metrics['roc_auc']:.4f}"
    )

Overall:   0%|                                           | 0/10 [00:00<?, ?it/s]
Epoch 1/10:   0%|                                        | 0/79 [00:00<?, ?it/s][A
Epoch 1/10:   0%|                     | 0/79 [00:09<?, ?it/s, train_loss=0.7374][A
Epoch 1/10:   1%|▏            | 1/79 [00:09<12:41,  9.76s/it, train_loss=0.7374][A
Epoch 1/10:   1%|▏            | 1/79 [00:18<12:41,  9.76s/it, train_loss=0.7422][A
Epoch 1/10:   3%|▎            | 2/79 [00:19<12:11,  9.50s/it, train_loss=0.7422][A
Epoch 1/10:   3%|▎            | 2/79 [00:28<12:11,  9.50s/it, train_loss=0.7320][A
Epoch 1/10:   4%|▍            | 3/79 [00:28<11:56,  9.42s/it, train_loss=0.7320][A
Epoch 1/10:   4%|▍            | 3/79 [00:37<11:56,  9.42s/it, train_loss=0.7303][A
Epoch 1/10:   5%|▋            | 4/79 [00:37<11:43,  9.38s/it, train_loss=0.7303][A
Epoch 1/10:   5%|▋            | 4/79 [00:46<11:43,  9.38s/it, train_loss=0.7219][A
Epoch 1/10:   6%|▊            | 5/79 [00:47<11:33,  9.37s/it, train_loss=0.7219

Epoch 1/10 • Train loss=0.5829 (767.2s) • Val Acc=0.8980 • Prec=0.9021 • Recall=0.8833 • F1=0.8926 • AUC=0.9647



Epoch 2/10:   0%|                                        | 0/79 [00:00<?, ?it/s][A
Epoch 2/10:   0%|                     | 0/79 [00:10<?, ?it/s, train_loss=0.2945][A
Epoch 2/10:   1%|▏            | 1/79 [00:10<14:11, 10.91s/it, train_loss=0.2945][A
Epoch 2/10:   1%|▏            | 1/79 [00:21<14:11, 10.91s/it, train_loss=0.2958][A
Epoch 2/10:   3%|▎            | 2/79 [00:21<13:37, 10.62s/it, train_loss=0.2958][A
Epoch 2/10:   3%|▎            | 2/79 [00:31<13:37, 10.62s/it, train_loss=0.2787][A
Epoch 2/10:   4%|▍            | 3/79 [00:32<13:30, 10.66s/it, train_loss=0.2787][A
Epoch 2/10:   4%|▍            | 3/79 [00:42<13:30, 10.66s/it, train_loss=0.2910][A
Epoch 2/10:   5%|▋            | 4/79 [00:42<13:16, 10.62s/it, train_loss=0.2910][A
Epoch 2/10:   5%|▋            | 4/79 [00:52<13:16, 10.62s/it, train_loss=0.2945][A
Epoch 2/10:   6%|▊            | 5/79 [00:52<12:54, 10.46s/it, train_loss=0.2945][A
Epoch 2/10:   6%|▊            | 5/79 [01:02<12:54, 10.46s/it, train_loss=0.

Epoch 2/10 • Train loss=0.2406 (802.0s) • Val Acc=0.9190 • Prec=0.9182 • Recall=0.9125 • F1=0.9154 • AUC=0.9773



Epoch 3/10:   0%|                                        | 0/79 [00:00<?, ?it/s][A
Epoch 3/10:   0%|                     | 0/79 [00:10<?, ?it/s, train_loss=0.1169][A
Epoch 3/10:   1%|▏            | 1/79 [00:10<13:44, 10.57s/it, train_loss=0.1169][A
Epoch 3/10:   1%|▏            | 1/79 [00:20<13:44, 10.57s/it, train_loss=0.1413][A
Epoch 3/10:   3%|▎            | 2/79 [00:20<13:15, 10.34s/it, train_loss=0.1413][A
Epoch 3/10:   3%|▎            | 2/79 [00:30<13:15, 10.34s/it, train_loss=0.1418][A
Epoch 3/10:   4%|▍            | 3/79 [00:30<12:59, 10.26s/it, train_loss=0.1418][A
Epoch 3/10:   4%|▍            | 3/79 [00:40<12:59, 10.26s/it, train_loss=0.1522][A
Epoch 3/10:   5%|▋            | 4/79 [00:41<12:47, 10.24s/it, train_loss=0.1522][A
Epoch 3/10:   5%|▋            | 4/79 [00:51<12:47, 10.24s/it, train_loss=0.1490][A
Epoch 3/10:   6%|▊            | 5/79 [00:51<12:36, 10.22s/it, train_loss=0.1490][A
Epoch 3/10:   6%|▊            | 5/79 [01:01<12:36, 10.22s/it, train_loss=0.

Epoch 3/10 • Train loss=0.2115 (801.4s) • Val Acc=0.9180 • Prec=0.9146 • Recall=0.9146 • F1=0.9146 • AUC=0.9759



Epoch 4/10:   0%|                                        | 0/79 [00:00<?, ?it/s][A
Epoch 4/10:   0%|                     | 0/79 [00:10<?, ?it/s, train_loss=0.2839][A
Epoch 4/10:   1%|▏            | 1/79 [00:10<14:03, 10.82s/it, train_loss=0.2839][A
Epoch 4/10:   1%|▏            | 1/79 [00:20<14:03, 10.82s/it, train_loss=0.2290][A
Epoch 4/10:   3%|▎            | 2/79 [00:21<13:26, 10.48s/it, train_loss=0.2290][A
Epoch 4/10:   3%|▎            | 2/79 [00:31<13:26, 10.48s/it, train_loss=0.2144][A
Epoch 4/10:   4%|▍            | 3/79 [00:31<13:09, 10.39s/it, train_loss=0.2144][A
Epoch 4/10:   4%|▍            | 3/79 [00:41<13:09, 10.39s/it, train_loss=0.1920][A
Epoch 4/10:   5%|▋            | 4/79 [00:41<12:55, 10.34s/it, train_loss=0.1920][A
Epoch 4/10:   5%|▋            | 4/79 [00:51<12:55, 10.34s/it, train_loss=0.1808][A
Epoch 4/10:   6%|▊            | 5/79 [00:51<12:44, 10.33s/it, train_loss=0.1808][A
Epoch 4/10:   6%|▊            | 5/79 [01:01<12:44, 10.33s/it, train_loss=0.

Epoch 4/10 • Train loss=0.1956 (801.7s) • Val Acc=0.9190 • Prec=0.9030 • Recall=0.9313 • F1=0.9169 • AUC=0.9761



Epoch 5/10:   0%|                                        | 0/79 [00:00<?, ?it/s][A
Epoch 5/10:   0%|                     | 0/79 [00:10<?, ?it/s, train_loss=0.3572][A
Epoch 5/10:   1%|▏            | 1/79 [00:10<13:56, 10.72s/it, train_loss=0.3572][A
Epoch 5/10:   1%|▏            | 1/79 [00:20<13:56, 10.72s/it, train_loss=0.3136][A
Epoch 5/10:   3%|▎            | 2/79 [00:21<13:31, 10.54s/it, train_loss=0.3136][A
Epoch 5/10:   3%|▎            | 2/79 [00:31<13:31, 10.54s/it, train_loss=0.2788][A
Epoch 5/10:   4%|▍            | 3/79 [00:31<13:12, 10.42s/it, train_loss=0.2788][A
Epoch 5/10:   4%|▍            | 3/79 [00:41<13:12, 10.42s/it, train_loss=0.2453][A
Epoch 5/10:   5%|▋            | 4/79 [00:41<13:01, 10.43s/it, train_loss=0.2453][A
Epoch 5/10:   5%|▋            | 4/79 [00:52<13:01, 10.43s/it, train_loss=0.2225][A
Epoch 5/10:   6%|▊            | 5/79 [00:52<12:54, 10.46s/it, train_loss=0.2225][A
Epoch 5/10:   6%|▊            | 5/79 [01:02<12:54, 10.46s/it, train_loss=0.

Epoch 5/10 • Train loss=0.1818 (891.0s) • Val Acc=0.9150 • Prec=0.9072 • Recall=0.9167 • F1=0.9119 • AUC=0.9757



Epoch 6/10:   0%|                                        | 0/79 [00:00<?, ?it/s][A
Epoch 6/10:   0%|                     | 0/79 [00:11<?, ?it/s, train_loss=0.0995][A
Epoch 6/10:   1%|▏            | 1/79 [00:12<15:36, 12.01s/it, train_loss=0.0995][A
Epoch 6/10:   1%|▏            | 1/79 [00:22<15:36, 12.01s/it, train_loss=0.1392][A
Epoch 6/10:   3%|▎            | 2/79 [00:22<14:26, 11.25s/it, train_loss=0.1392][A
Epoch 6/10:   3%|▎            | 2/79 [00:33<14:26, 11.25s/it, train_loss=0.1122][A
Epoch 6/10:   4%|▍            | 3/79 [00:34<14:18, 11.29s/it, train_loss=0.1122][A
Epoch 6/10:   4%|▍            | 3/79 [00:44<14:18, 11.29s/it, train_loss=0.1519][A
Epoch 6/10:   5%|▋            | 4/79 [00:45<14:01, 11.22s/it, train_loss=0.1519][A
Epoch 6/10:   5%|▋            | 4/79 [00:56<14:01, 11.22s/it, train_loss=0.1526][A
Epoch 6/10:   6%|▊            | 5/79 [00:56<13:51, 11.24s/it, train_loss=0.1526][A
Epoch 6/10:   6%|▊            | 5/79 [01:07<13:51, 11.24s/it, train_loss=0.

Epoch 6/10 • Train loss=0.1744 (929.7s) • Val Acc=0.9200 • Prec=0.9082 • Recall=0.9271 • F1=0.9175 • AUC=0.9765



Epoch 7/10:   0%|                                        | 0/79 [00:00<?, ?it/s][A
Epoch 7/10:   0%|                     | 0/79 [00:11<?, ?it/s, train_loss=0.1545][A
Epoch 7/10:   1%|▏            | 1/79 [00:12<15:40, 12.05s/it, train_loss=0.1545][A
Epoch 7/10:   1%|▏            | 1/79 [00:23<15:40, 12.05s/it, train_loss=0.1467][A
Epoch 7/10:   3%|▎            | 2/79 [00:23<15:11, 11.84s/it, train_loss=0.1467][A
Epoch 7/10:   3%|▎            | 2/79 [00:35<15:11, 11.84s/it, train_loss=0.1377][A
Epoch 7/10:   4%|▍            | 3/79 [00:35<14:54, 11.76s/it, train_loss=0.1377][A
Epoch 7/10:   4%|▍            | 3/79 [00:47<14:54, 11.76s/it, train_loss=0.1224][A
Epoch 7/10:   5%|▋            | 4/79 [00:47<14:47, 11.84s/it, train_loss=0.1224][A
Epoch 7/10:   5%|▋            | 4/79 [00:58<14:47, 11.84s/it, train_loss=0.1345][A
Epoch 7/10:   6%|▊            | 5/79 [00:59<14:33, 11.80s/it, train_loss=0.1345][A
Epoch 7/10:   6%|▊            | 5/79 [01:10<14:33, 11.80s/it, train_loss=0.

Epoch 7/10 • Train loss=0.1627 (920.1s) • Val Acc=0.9150 • Prec=0.9089 • Recall=0.9146 • F1=0.9117 • AUC=0.9756



Epoch 8/10:   0%|                                        | 0/79 [00:00<?, ?it/s][A
Epoch 8/10:   0%|                     | 0/79 [00:11<?, ?it/s, train_loss=0.2009][A
Epoch 8/10:   1%|▏            | 1/79 [00:12<15:37, 12.02s/it, train_loss=0.2009][A
Epoch 8/10:   1%|▏            | 1/79 [00:23<15:37, 12.02s/it, train_loss=0.2169][A
Epoch 8/10:   3%|▎            | 2/79 [00:23<14:53, 11.61s/it, train_loss=0.2169][A
Epoch 8/10:   3%|▎            | 2/79 [00:35<14:53, 11.61s/it, train_loss=0.1992][A
Epoch 8/10:   4%|▍            | 3/79 [00:35<14:58, 11.83s/it, train_loss=0.1992][A
Epoch 8/10:   4%|▍            | 3/79 [00:46<14:58, 11.83s/it, train_loss=0.1679][A
Epoch 8/10:   5%|▋            | 4/79 [00:47<14:41, 11.75s/it, train_loss=0.1679][A
Epoch 8/10:   5%|▋            | 4/79 [00:58<14:41, 11.75s/it, train_loss=0.1671][A
Epoch 8/10:   6%|▊            | 5/79 [00:58<14:27, 11.73s/it, train_loss=0.1671][A
Epoch 8/10:   6%|▊            | 5/79 [01:10<14:27, 11.73s/it, train_loss=0.

Epoch 8/10 • Train loss=0.1573 (913.1s) • Val Acc=0.9180 • Prec=0.9198 • Recall=0.9083 • F1=0.9140 • AUC=0.9758



Epoch 9/10:   0%|                                        | 0/79 [00:00<?, ?it/s][A
Epoch 9/10:   0%|                     | 0/79 [00:11<?, ?it/s, train_loss=0.0997][A
Epoch 9/10:   1%|▏            | 1/79 [00:11<15:00, 11.55s/it, train_loss=0.0997][A
Epoch 9/10:   1%|▏            | 1/79 [00:23<15:00, 11.55s/it, train_loss=0.1220][A
Epoch 9/10:   3%|▎            | 2/79 [00:23<15:00, 11.69s/it, train_loss=0.1220][A
Epoch 9/10:   3%|▎            | 2/79 [00:34<15:00, 11.69s/it, train_loss=0.1223][A
Epoch 9/10:   4%|▍            | 3/79 [00:35<14:50, 11.72s/it, train_loss=0.1223][A
Epoch 9/10:   4%|▍            | 3/79 [00:47<14:50, 11.72s/it, train_loss=0.1177][A
Epoch 9/10:   5%|▋            | 4/79 [00:47<14:54, 11.93s/it, train_loss=0.1177][A
Epoch 9/10:   5%|▋            | 4/79 [00:58<14:54, 11.93s/it, train_loss=0.1422][A
Epoch 9/10:   6%|▊            | 5/79 [00:59<14:35, 11.84s/it, train_loss=0.1422][A
Epoch 9/10:   6%|▊            | 5/79 [01:10<14:35, 11.84s/it, train_loss=0.

Epoch 9/10 • Train loss=0.1537 (902.0s) • Val Acc=0.9200 • Prec=0.8922 • Recall=0.9479 • F1=0.9192 • AUC=0.9761



Epoch 10/10:   0%|                                       | 0/79 [00:00<?, ?it/s][A
Epoch 10/10:   0%|                    | 0/79 [00:10<?, ?it/s, train_loss=0.0624][A
Epoch 10/10:   1%|▏           | 1/79 [00:11<14:29, 11.15s/it, train_loss=0.0624][A
Epoch 10/10:   1%|▏           | 1/79 [00:21<14:29, 11.15s/it, train_loss=0.1249][A
Epoch 10/10:   3%|▎           | 2/79 [00:22<14:17, 11.13s/it, train_loss=0.1249][A
Epoch 10/10:   3%|▎           | 2/79 [00:33<14:17, 11.13s/it, train_loss=0.1347][A
Epoch 10/10:   4%|▍           | 3/79 [00:33<14:17, 11.29s/it, train_loss=0.1347][A
Epoch 10/10:   4%|▍           | 3/79 [00:44<14:17, 11.29s/it, train_loss=0.1185][A
Epoch 10/10:   5%|▌           | 4/79 [00:45<14:06, 11.29s/it, train_loss=0.1185][A
Epoch 10/10:   5%|▌           | 4/79 [00:56<14:06, 11.29s/it, train_loss=0.1067][A
Epoch 10/10:   6%|▊           | 5/79 [00:56<13:59, 11.34s/it, train_loss=0.1067][A
Epoch 10/10:   6%|▊           | 5/79 [01:07<13:59, 11.34s/it, train_loss=0.

KeyboardInterrupt: 