In [None]:
import os
import shutil
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from PIL import Image
from tqdm.notebook import tqdm
from sklearn.metrics import f1_score
from google.colab import drive

# ==========================================
# 1. SETUP & FAST COPY SYSTEM 🚀
# ==========================================
# Location of the dataset zip in Google Drive, and the local Colab directory
# it gets extracted into (reading many small files directly from Drive is slow).
DRIVE_ZIP_PATH = '/content/drive/MyDrive/ML_Project/task4.zip'
LOCAL_DIR = '/content/task4_local_data'

print("🚀 Initiating Grandmaster Pipeline...")

# 1.1 Mount Drive (skipped when already mounted, e.g. when the cell is re-run)
if not os.path.ismount('/content/drive'):
    drive.mount('/content/drive')

# 1.2 Copy & unzip (only if the data is not already on the local disk)
if not os.path.exists(LOCAL_DIR):
    print(f"⏳ Speed Boost: กำลัง Copy ไฟล์จาก {DRIVE_ZIP_PATH} ...")

    if not os.path.exists(DRIVE_ZIP_PATH):
        raise FileNotFoundError(f"❌ ไม่เจอไฟล์ {DRIVE_ZIP_PATH} ใน Drive! กรุณาเช็คชื่อไฟล์ครับ")

    temp_zip = '/content/temp_data.zip'
    try:
        # shutil raises on failure; the previous os.system calls returned
        # exit codes that were ignored, so a failed copy/unzip went unnoticed.
        shutil.copyfile(DRIVE_ZIP_PATH, temp_zip)
        print("📦 Unzipping... (Fast SSD)")
        shutil.unpack_archive(temp_zip, LOCAL_DIR, format='zip')
    except BaseException:
        # Remove a partially extracted LOCAL_DIR; otherwise the next run would
        # find it and wrongly skip extraction ("Data Already Exists").
        shutil.rmtree(LOCAL_DIR, ignore_errors=True)
        raise
    finally:
        # Always clean up the temporary zip copy.
        if os.path.exists(temp_zip):
            os.remove(temp_zip)
    print(f"✅ Data Ready at: {LOCAL_DIR}")
else:
    print("⚡ Data Already Exists on SSD")

# 1.3 Auto-fix paths: the zip may wrap everything in a nested folder, so use
# the first directory (os.walk order) that contains train.csv.
BASE_PATH = LOCAL_DIR
if not os.path.exists(os.path.join(BASE_PATH, 'train.csv')):
    BASE_PATH = next(
        (root for root, _dirs, files in os.walk(LOCAL_DIR) if 'train.csv' in files),
        None,
    )
    if BASE_PATH is None:
        # Fail here with a clear message instead of a confusing read_csv error later.
        raise FileNotFoundError(f"❌ train.csv not found anywhere under {LOCAL_DIR}")

print(f"📂 Working Directory: {BASE_PATH}")

TRAIN_CSV = os.path.join(BASE_PATH, 'train.csv')
VAL_CSV   = os.path.join(BASE_PATH, 'val.csv')
TEST_CSV  = os.path.join(BASE_PATH, 'test.csv')

TRAIN_IMG_DIR = os.path.join(BASE_PATH, 'train')
VAL_IMG_DIR   = os.path.join(BASE_PATH, 'val')
TEST_IMG_DIR  = os.path.join(BASE_PATH, 'test')

# ==========================================
# 2. CONFIGURATION
# ==========================================
# Training hyper-parameters shared by the loaders, optimizer and model head.
BATCH_SIZE = 32        # samples per DataLoader batch
EPOCHS = 15            # number of passes over the training set
LEARNING_RATE = 1e-4   # base learning rate for the optimizer
NUM_CLASSES = 5        # output width of the classification head

# Prefer the GPU whenever CUDA is visible; otherwise run on the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")
print(f"⚙️ Device: {DEVICE}")

# ==========================================
# 3. DATASET CLASS
# ==========================================
class GameDataset(Dataset):
    """Image dataset indexed by a CSV file.

    Every CSV row needs a ``file_name`` column (path relative to ``img_dir``).
    Labeled splits also need a ``label`` column; with ``is_test=True`` an
    ``id`` column is read instead.

    Args:
        csv_file: Path of the CSV index, loaded with ``pandas.read_csv``.
        img_dir: Directory that ``file_name`` entries are relative to.
        transform: Optional callable applied to the RGB PIL image.
        is_test: If True, items are ``(image, id)``; otherwise ``(image, label)``
            with ``label`` cast to ``int``.
    """

    def __init__(self, csv_file, img_dir, transform=None, is_test=False):
        self.df = pd.read_csv(csv_file)
        self.img_dir = img_dir
        self.transform = transform
        self.is_test = is_test

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        img_path = os.path.join(self.img_dir, row['file_name'])
        try:
            # The context manager closes the underlying file handle;
            # convert() returns a new, fully loaded image.
            with Image.open(img_path) as img:
                image = img.convert("RGB")
        except OSError:
            # Missing, unreadable or unidentifiable file (PIL's
            # UnidentifiedImageError subclasses OSError): fall back to a black
            # placeholder so a single bad file does not abort the epoch.
            # Anything else (bugs, KeyboardInterrupt) now propagates.
            image = Image.new('RGB', (224, 224), color='black')

        if self.transform:
            image = self.transform(image)

        if self.is_test:
            return image, row['id']
        return image, int(row['label'])

# ==========================================
# 4. TRANSFORMS (Augmentation)
# ==========================================
# Common tail of every pipeline: PIL -> tensor, then normalize with the
# ImageNet channel mean/std used by the pretrained torchvision backbones.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]
_tensor_tail = [transforms.ToTensor(), transforms.Normalize(_IMAGENET_MEAN, _IMAGENET_STD)]

# Training: resize to 256x256, random 224x224 crop, then flip / rotate / jitter.
train_transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.RandomCrop((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ColorJitter(brightness=0.2, contrast=0.2),
    *_tensor_tail,
])

# Validation: deterministic resize straight to 224x224.
# NOTE(review): training/TTA crop 224 out of a 256 resize, so objects appear at a
# slightly different scale than here — consider Resize(256) + CenterCrop(224).
val_transform = transforms.Compose([transforms.Resize((224, 224)), *_tensor_tail])

# Test-time augmentation: random crop + random horizontal flip, so each pass
# over the test set sees a different view of every image.
tta_transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.RandomCrop((224, 224)),
    transforms.RandomHorizontalFlip(p=0.5),
    *_tensor_tail,
])

print("📦 Loading DataLoaders...")
_loader_opts = dict(batch_size=BATCH_SIZE, num_workers=2, pin_memory=True)
train_loader = DataLoader(
    GameDataset(TRAIN_CSV, TRAIN_IMG_DIR, transform=train_transform),
    shuffle=True,
    **_loader_opts,
)
val_loader = DataLoader(
    GameDataset(VAL_CSV, VAL_IMG_DIR, transform=val_transform),
    shuffle=False,
    **_loader_opts,
)

# ==========================================
# 5. MODEL & TRAINING (GM Tricks)
# ==========================================
print("🏗️ Building Swin Transformer...")
try:
    # Pretrained Swin-T with its head replaced by a NUM_CLASSES-way linear layer.
    model = models.swin_t(weights="DEFAULT")
    model.head = nn.Linear(model.head.in_features, NUM_CLASSES)
except Exception as err:
    # Fallback (e.g. torchvision without swin_t, or a failed weight download).
    # Narrowed from a bare `except:` so Ctrl-C is not swallowed, and the
    # underlying reason is printed instead of being hidden.
    print("⚠️ Swin failed, using ResNet50")
    print(f"   reason: {err!r}")
    model = models.resnet50(weights="DEFAULT")
    model.fc = nn.Linear(model.fc.in_features, NUM_CLASSES)

model = model.to(DEVICE)

optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=1e-2)

# 🔥 GM Trick 1: label smoothing (keeps the model from becoming over-confident,
# which tends to help generalization).
criterion = nn.CrossEntropyLoss(label_smoothing=0.1)

# 🔥 GM Trick 2: cosine annealing with warm restarts (the LR follows a wave).
# With T_0=5, T_mult=2 the cycles cover epochs [0, 5) and [5, 15).
scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=5, T_mult=2)

# Start below any reachable macro-F1 (>= 0) so the first epoch always writes a
# checkpoint; with 0.0 and `>` a run stuck at F1 == 0 would never save one and
# the later torch.load of best_model_task4.pth would fail.
best_val_f1 = -1.0
print("🔥 Start Training...")

num_batches = len(train_loader)
for epoch in range(EPOCHS):
    model.train()
    running_loss = 0.0

    pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS}")
    for batch_idx, (images, labels) in enumerate(pbar):
        images, labels = images.to(DEVICE), labels.to(DEVICE)
        optimizer.zero_grad()

        outputs = model(images)
        loss = criterion(outputs, labels)

        loss.backward()
        optimizer.step()

        # Per-batch step with a fractional epoch, as in the PyTorch docs for
        # CosineAnnealingWarmRestarts. Use our own batch counter: tqdm only
        # syncs `pbar.n` when it refreshes the display, so it lags behind.
        scheduler.step(epoch + batch_idx / num_batches)

        running_loss += loss.item()
        pbar.set_postfix({'loss': f"{loss.item():.4f}"})

    # Validation: macro-F1 of argmax predictions over the whole validation set.
    model.eval()
    preds, targs = [], []
    with torch.no_grad():
        for images, labels in val_loader:
            logits = model(images.to(DEVICE))
            preds.extend(logits.argmax(dim=1).cpu().numpy())
            targs.extend(labels.numpy())  # targets are only needed on the CPU

    val_f1 = f1_score(targs, preds, average="macro")
    avg_loss = running_loss / num_batches
    print(f"📊 Epoch {epoch+1}: Loss={avg_loss:.4f} | Val F1={val_f1:.4f}")

    # Keep only the checkpoint with the best validation F1 so far.
    if val_f1 > best_val_f1:
        best_val_f1 = val_f1
        torch.save(model.state_dict(), "best_model_task4.pth")
        print("⭐️ New Best Model Saved!")

# ==========================================
# 6. PREDICTION WITH TTA (GM Trick 3) 🔮
# ==========================================
print("\n🔮 Predicting with Test-Time Augmentation (TTA)...")
# Load the best checkpoint written during training. map_location lets the
# load succeed even if the weights were saved from a different device.
model.load_state_dict(torch.load("best_model_task4.pth", map_location=DEVICE))
model.eval()

# Number of TTA rounds (more rounds = smoother average, but slower).
TTA_STEPS = 5

# The test set gets its own loader so that it uses the random TTA transform.
test_dataset_tta = GameDataset(TEST_CSV, TEST_IMG_DIR, transform=tta_transform, is_test=True)
test_loader_tta = DataLoader(test_dataset_tta, batch_size=BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=True)

# With shuffle=False the loader yields samples in CSV row order, so the ids can
# be read straight from the CSV. Collecting them from the batches instead is
# fragile: numeric ids are collated into tensors, and list.extend() on a tensor
# stores 0-d tensor objects in the submission rather than plain numbers.
test_ids = test_dataset_tta.df['id'].tolist()

round_probs = []  # one (num_test_samples, num_classes) array per TTA round
for step in range(TTA_STEPS):
    print(f"   🔄 TTA Round {step+1}/{TTA_STEPS}...")

    batch_probs = []
    with torch.no_grad():
        for images, _ids in tqdm(test_loader_tta, leave=False):
            outputs = model(images.to(DEVICE))
            # Convert logits to probabilities (0-1) so rounds can be averaged.
            batch_probs.append(torch.softmax(outputs, dim=1).cpu().numpy())

    # Join this round's batches back into one array in dataset order.
    round_probs.append(np.concatenate(batch_probs, axis=0))

# Average the class probabilities over all TTA rounds, then take the argmax.
final_probs = np.mean(np.stack(round_probs, axis=0), axis=0)
final_preds = np.argmax(final_probs, axis=1)

# Write the submission file.
sub = pd.DataFrame({"id": test_ids, "label": final_preds})
sub.to_csv("submission_task4.csv", index=False)
print(f"\n🎉 Done! Saved submission_task4.csv with TTA (Best F1: {best_val_f1:.4f})")

🚀 Initiating Grandmaster Pipeline...
Mounted at /content/drive
⏳ Speed Boost: กำลัง Copy ไฟล์จาก /content/drive/MyDrive/ML_Project/task4.zip ...
📦 Unzipping... (Fast SSD)
✅ Data Ready at: /content/task4_local_data
📂 Working Directory: /content/task4_local_data
⚙️ Device: cuda
📦 Loading DataLoaders...
🏗️ Building Swin Transformer...
Downloading: "https://download.pytorch.org/models/swin_t-704ceda3.pth" to /root/.cache/torch/hub/checkpoints/swin_t-704ceda3.pth


100%|██████████| 108M/108M [00:00<00:00, 155MB/s]


🔥 Start Training...


Epoch 1/15:   0%|          | 0/986 [00:00<?, ?it/s]

📊 Epoch 1: Loss=0.6742 | Val F1=0.7372
⭐️ New Best Model Saved!


Epoch 2/15:   0%|          | 0/986 [00:00<?, ?it/s]

📊 Epoch 2: Loss=0.5172 | Val F1=0.7554
⭐️ New Best Model Saved!


Epoch 3/15:   0%|          | 0/986 [00:00<?, ?it/s]

📊 Epoch 3: Loss=0.4730 | Val F1=0.7632
⭐️ New Best Model Saved!


Epoch 4/15:   0%|          | 0/986 [00:00<?, ?it/s]

📊 Epoch 4: Loss=0.4422 | Val F1=0.7607


Epoch 5/15:   0%|          | 0/986 [00:00<?, ?it/s]

📊 Epoch 5: Loss=0.4310 | Val F1=0.7730
⭐️ New Best Model Saved!


Epoch 6/15:   0%|          | 0/986 [00:00<?, ?it/s]

📊 Epoch 6: Loss=0.5098 | Val F1=0.7196


Epoch 7/15:   0%|          | 0/986 [00:00<?, ?it/s]

📊 Epoch 7: Loss=0.4790 | Val F1=0.6913


Epoch 8/15:   0%|          | 0/986 [00:00<?, ?it/s]

📊 Epoch 8: Loss=0.4623 | Val F1=0.7618


Epoch 9/15:   0%|          | 0/986 [00:00<?, ?it/s]

📊 Epoch 9: Loss=0.4464 | Val F1=0.7618


Epoch 10/15:   0%|          | 0/986 [00:00<?, ?it/s]

📊 Epoch 10: Loss=0.4325 | Val F1=0.7901
⭐️ New Best Model Saved!


Epoch 11/15:   0%|          | 0/986 [00:00<?, ?it/s]

📊 Epoch 11: Loss=0.4245 | Val F1=0.7697


Epoch 12/15:   0%|          | 0/986 [00:00<?, ?it/s]

📊 Epoch 12: Loss=0.4153 | Val F1=0.7612


Epoch 13/15:   0%|          | 0/986 [00:00<?, ?it/s]

📊 Epoch 13: Loss=0.4093 | Val F1=0.7703


Epoch 14/15:   0%|          | 0/986 [00:00<?, ?it/s]

📊 Epoch 14: Loss=0.4064 | Val F1=0.7746


Epoch 15/15:   0%|          | 0/986 [00:00<?, ?it/s]

📊 Epoch 15: Loss=0.4042 | Val F1=0.7736

🔮 Predicting with Test-Time Augmentation (TTA)...
   🔄 TTA Round 1/5...


  0%|          | 0/810 [00:00<?, ?it/s]

   🔄 TTA Round 2/5...


  0%|          | 0/810 [00:00<?, ?it/s]

   🔄 TTA Round 3/5...


  0%|          | 0/810 [00:00<?, ?it/s]

   🔄 TTA Round 4/5...


  0%|          | 0/810 [00:00<?, ?it/s]

   🔄 TTA Round 5/5...


  0%|          | 0/810 [00:00<?, ?it/s]


🎉 Done! Saved submission_task4.csv with TTA (Best F1: 0.7901)
