In [1]:
import os
import time
import numpy as np
import pandas as pd
import matplotlib
matplotlib.use('Agg')  # Non-interactive backend
import matplotlib.pyplot as plt
import seaborn as sns
from collections import Counter

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Subset, WeightedRandomSampler
from torchvision import transforms, datasets
import timm

from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, precision_recall_fscore_support, accuracy_score
from torch.optim.lr_scheduler import ReduceLROnPlateau



In [2]:
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
DATA_DIR = "/kaggle/input/malimg-original/malimg_paper_dataset_imgs"
BATCH_SIZE = 16  
IMG_SIZE = 224   # DeiT menggunakan 224x224
NUM_CLASSES = 25
EPOCHS = 20
LEARNING_RATE = 5e-5  # Learning rate lebih kecil untuk ViT
OUTPUT_DIR = "/kaggle/working/"
os.makedirs(OUTPUT_DIR, exist_ok=True)

In [3]:
temp_transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor()
])
temp_dataset = datasets.ImageFolder(root=DATA_DIR, transform=temp_transform)
loader = DataLoader(temp_dataset, batch_size=128, shuffle=False, num_workers=2)

mean = 0.
std = 0.
total_samples = 0
for data, _ in loader:
    batch_samples = data.size(0)
    data = data.view(batch_samples, data.size(1), -1)
    mean += data.mean(2).sum(0)
    std += data.std(2).sum(0)
    total_samples += batch_samples
mean /= total_samples
std /= total_samples

print(f"Dataset mean: {mean.item():.4f}, std: {std.item():.4f}")

# Data Augmentation untuk training (konsisten dengan SwinV2)
train_transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),
    transforms.Resize((int(IMG_SIZE * 1.1), int(IMG_SIZE * 1.1))),
    transforms.RandomResizedCrop(size=(IMG_SIZE, IMG_SIZE), scale=(0.95, 1.0), ratio=(0.95, 1.05)),
    transforms.ToTensor(),
    transforms.RandomApply([transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 1.0))], p=0.3),
    transforms.RandomApply([
        transforms.Lambda(lambda x: torch.clamp(x + torch.randn_like(x) * 0.01, 0, 1))
    ], p=0.3),
    transforms.Normalize(mean=mean, std=std)
])

test_transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std)
])

Dataset mean: 0.4455, std: 0.1723


In [4]:
raw_dataset = datasets.ImageFolder(root=DATA_DIR)
all_targets = [label for _, label in raw_dataset.samples]
all_indices = list(range(len(raw_dataset)))

split_path = os.path.join(OUTPUT_DIR, "train_test_split.npz")
if os.path.exists(split_path):
    splits = np.load(split_path)
    train_idx, test_idx = splits['train_idx'], splits['test_idx']
else:
    train_idx, test_idx = train_test_split(
        all_indices,
        test_size=0.2,
        stratify=all_targets,
        random_state=42
    )
    np.savez(split_path, train_idx=train_idx, test_idx=test_idx)

train_dataset_raw = datasets.ImageFolder(root=DATA_DIR, transform=train_transform)
test_dataset_raw = datasets.ImageFolder(root=DATA_DIR, transform=test_transform)

train_dataset = Subset(train_dataset_raw, train_idx)
test_dataset = Subset(test_dataset_raw, test_idx)

# Weighted sampling untuk handle class imbalance
train_targets = [all_targets[i] for i in train_idx]
class_counts = Counter(train_targets)
weights = [1.0 / class_counts[train_targets[i]] for i in range(len(train_targets))]
sampler = WeightedRandomSampler(weights, len(weights), replacement=True)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, sampler=sampler, num_workers=2, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=True)

class_names = raw_dataset.classes


In [5]:
model = timm.create_model('deit_small_patch16_224', pretrained=True, num_classes=NUM_CLASSES)

# Modifikasi patch embedding untuk 1 channel (grayscale)
original_conv = model.patch_embed.proj
model.patch_embed.proj = nn.Conv2d(
    in_channels=1,  # Grayscale input
    out_channels=original_conv.out_channels,
    kernel_size=original_conv.kernel_size,
    stride=original_conv.stride,
    padding=original_conv.padding,
    bias=original_conv.bias is not None
)

# Initialize weights dengan averaging pretrained RGB weights
with torch.no_grad():
    # Average RGB channels ke single grayscale channel
    model.patch_embed.proj.weight[:, 0, :, :] = original_conv.weight.mean(dim=1)
    if model.patch_embed.proj.bias is not None:
        model.patch_embed.proj.bias.copy_(original_conv.bias)

# Modifikasi head dengan dropout
model.head = nn.Sequential(
    nn.Dropout(0.3),
    nn.Linear(model.head.in_features, NUM_CLASSES)
)

model.to(DEVICE)

print(f"Model: DeiT Small")
print(f"Input size: {IMG_SIZE}x{IMG_SIZE}")
print(f"Patch size: 16x16")
print(f"Number of classes: {NUM_CLASSES}")

model.safetensors:   0%|          | 0.00/88.2M [00:00<?, ?B/s]

Model: DeiT Small
Input size: 224x224
Patch size: 16x16
Number of classes: 25


In [6]:
class FocalLoss(nn.Module):
    def __init__(self, gamma=2, alpha=None, reduction='mean'):
        super().__init__()
        self.gamma = gamma
        self.alpha = alpha
        self.reduction = reduction

    def forward(self, inputs, targets):
        # Pastikan inputs adalah 2D [batch_size, num_classes]
        if inputs.dim() > 2:
            inputs = inputs.view(inputs.size(0), -1)
        
        # Pastikan targets adalah 1D tensor dengan tipe long
        targets = targets.view(-1).long()
        
        # Validasi
        num_classes = inputs.size(1)
        assert targets.min() >= 0, f"Target min value {targets.min()} is negative"
        assert targets.max() < num_classes, f"Target max value {targets.max()} >= num_classes {num_classes}"
        
        # Hitung cross entropy dengan log softmax
        log_probs = torch.nn.functional.log_softmax(inputs, dim=-1)
        
        # Gunakan nll_loss untuk lebih aman
        ce_loss = torch.nn.functional.nll_loss(log_probs, targets, reduction='none')
        
        # Hitung probability untuk focal weight
        pt = torch.exp(-ce_loss)
        
        # Hitung focal loss
        focal_loss = ((1 - pt) ** self.gamma) * ce_loss
        
        # Apply alpha if specified
        if self.alpha is not None:
            if isinstance(self.alpha, (float, int)):
                alpha_t = self.alpha
            else:
                alpha_t = self.alpha.gather(0, targets)
            focal_loss = alpha_t * focal_loss

        if self.reduction == 'mean':
            return focal_loss.mean()
        elif self.reduction == 'sum':
            return focal_loss.sum()
        else:
            return focal_loss

criterion = FocalLoss(gamma=2)
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=0.05)
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3, verbose=True)




In [7]:
print("\n" + "="*60)
print("VALIDASI DATA")
print("="*60)
print(f"Number of classes: {len(class_names)}")
print(f"Train samples: {len(train_dataset)}")
print(f"Test samples: {len(test_dataset)}")

# Test forward pass
print("\nTesting forward pass...")
model.eval()
with torch.no_grad():
    test_batch = next(iter(train_loader))
    test_input, test_label = test_batch[0].to(DEVICE), test_batch[1].to(DEVICE)
    test_output = model(test_input)
    print(f"Input shape: {test_input.shape}")
    print(f"Output shape: {test_output.shape}")
    print(f"Expected output shape: [{test_input.size(0)}, {NUM_CLASSES}]")
    
    if test_output.dim() > 2:
        print(f"Output has {test_output.dim()} dimensions, will reshape during training")

print("\n" + "="*60)
print("STARTING TRAINING")
print("="*60)

train_losses, train_accuracies = [], []
val_losses, val_accuracies = [], []
best_val_acc = 0.0

for epoch in range(EPOCHS):
    model.train()
    running_loss = 0.0
    correct_train = 0
    total_train = 0
    
    for inputs, labels in train_loader:
        inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
        optimizer.zero_grad()
        outputs = model(inputs)
        
        # Reshape jika diperlukan
        if outputs.dim() > 2:
            outputs = outputs.view(outputs.size(0), -1)
        
        loss = criterion(outputs, labels)
        loss.backward()
        
        # Gradient clipping untuk stabilitas
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        
        optimizer.step()
        
        running_loss += loss.item()
        _, predicted = torch.max(outputs, 1)
        total_train += labels.size(0)
        correct_train += (predicted == labels).sum().item()
    
    train_acc = 100 * correct_train / total_train
    train_loss = running_loss / len(train_loader)
    train_losses.append(train_loss)
    train_accuracies.append(train_acc)

    # Validation
    model.eval()
    val_loss = 0.0
    correct_val = 0
    total_val = 0
    
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
            outputs = model(inputs)
            
            # Reshape jika diperlukan
            if outputs.dim() > 2:
                outputs = outputs.view(outputs.size(0), -1)
            
            loss = criterion(outputs, labels)
            val_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            total_val += labels.size(0)
            correct_val += (predicted == labels).sum().item()
    
    val_acc = 100 * correct_val / total_val
    val_loss_avg = val_loss / len(test_loader)
    val_losses.append(val_loss_avg)
    val_accuracies.append(val_acc)
    
    # Update learning rate
    scheduler.step(val_loss_avg)
    
    # Save best model
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save(model.state_dict(), os.path.join(OUTPUT_DIR, "deit_malimg_best.pth"))
    
    print(f"Epoch {epoch+1}/{EPOCHS} - Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%, Val Loss: {val_loss_avg:.4f}, Val Acc: {val_acc:.2f}%")

# Save final model
torch.save(model.state_dict(), os.path.join(OUTPUT_DIR, "deit_malimg_final.pth"))

# Load best model for evaluation
model.load_state_dict(torch.load(os.path.join(OUTPUT_DIR, "deit_malimg_best.pth")))

# Evaluasi
print("\n" + "="*60)
print("FINAL EVALUATION")
print("="*60)

model.eval()
all_preds, all_labels, inference_times = [], [], []
with torch.no_grad():
    for inputs, labels in test_loader:
        inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
        start = time.time()
        outputs = model(inputs)
        
        # Reshape jika diperlukan
        if outputs.dim() > 2:
            outputs = outputs.view(outputs.size(0), -1)
        
        end = time.time()
        inference_times.append(end - start)
        
        _, preds = torch.max(outputs, 1)
        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

all_labels = np.array(all_labels)
all_preds = np.array(all_preds)

# Metrik
accuracy = accuracy_score(all_labels, all_preds)
precision, recall, f1, support = precision_recall_fscore_support(
    all_labels, all_preds, average=None, labels=range(NUM_CLASSES), zero_division=0
)
precision_avg = np.mean(precision)
recall_avg = np.mean(recall)
f1_avg = np.mean(f1)

# Per-class metrics
report_df = pd.DataFrame({
    'class': class_names,
    'precision': precision,
    'recall': recall,
    'f1-score': f1,
    'support': support
})
report_df.to_csv(os.path.join(OUTPUT_DIR, "DeiT_per_class_metrics.csv"), index=False)

# Confusion Matrix
cm = confusion_matrix(all_labels, all_preds)
pd.DataFrame(cm, index=class_names, columns=class_names).to_csv(os.path.join(OUTPUT_DIR, "DeiT_confusion_matrix.csv"))

plt.figure(figsize=(14, 12))
sns.heatmap(cm, annot=False, xticklabels=class_names, yticklabels=class_names, cmap='Blues', fmt='d')
plt.title("Confusion Matrix - DeiT (224x224)")
plt.xlabel("Predicted")
plt.ylabel("True")
plt.xticks(rotation=45, ha='right')
plt.yticks(rotation=0)
plt.tight_layout()
plt.savefig(os.path.join(OUTPUT_DIR, "DeiT_confusion_matrix.png"), dpi=150)
plt.close()

# Training curves
plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.plot(train_losses, label='Train Loss', marker='o')
plt.plot(val_losses, label='Val Loss', marker='s')
plt.legend()
plt.title("Loss Curves")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.grid(True, alpha=0.3)

plt.subplot(1, 2, 2)
plt.plot(train_accuracies, label='Train Acc', marker='o')
plt.plot(val_accuracies, label='Val Acc', marker='s')
plt.legend()
plt.title("Accuracy Curves")
plt.xlabel("Epoch")
plt.ylabel("Accuracy (%)")
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig(os.path.join(OUTPUT_DIR, "DeiT_training_curves.png"), dpi=150)
plt.close()

# Model statistics
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
model_size_mb = total_params * 4 / (1024 ** 2)
avg_time_per_image = np.mean(inference_times) / BATCH_SIZE
total_inference_time = sum(inference_times)
throughput = len(test_dataset) / total_inference_time

summary = {
    "Model": "DeiT Small (224x224, 1-Channel, Focal Loss)",
    "Accuracy": accuracy,
    "Macro Precision": precision_avg,
    "Macro Recall": recall_avg,
    "Macro F1": f1_avg,
    "Best Val Accuracy": best_val_acc,
    "Total Params": total_params,
    "Trainable Params": trainable_params,
    "Model Size (MB)": model_size_mb,
    "Avg Inference Time (ms)": avg_time_per_image * 1000,
    "Throughput (img/sec)": throughput,
    "Hardware": str(DEVICE) + (f" ({torch.cuda.get_device_name(0)})" if torch.cuda.is_available() else "")
}
pd.DataFrame([summary]).to_csv(os.path.join(OUTPUT_DIR, "DeiT_summary.csv"), index=False)

print("RESULTS SUMMARY")
print("="*60)
print(f"Test Accuracy: {accuracy*100:.2f}%")
print(f"Best Val Accuracy: {best_val_acc:.2f}%")
print(f"Macro Precision: {precision_avg:.4f}")
print(f"Macro Recall: {recall_avg:.4f}")
print(f"Macro F1-Score: {f1_avg:.4f}")
print(f"Total Parameters: {total_params:,}")
print(f"Model Size: {model_size_mb:.2f} MB")
print(f"Avg Inference Time: {avg_time_per_image*1000:.2f} ms/image")
print(f"Throughput: {throughput:.2f} images/sec")
print("\n DeiT training complete! All results saved to /kaggle/working/")


VALIDASI DATA
Number of classes: 25
Train samples: 7471
Test samples: 1868

Testing forward pass...
Input shape: torch.Size([16, 1, 224, 224])
Output shape: torch.Size([16, 25])
Expected output shape: [16, 25]

STARTING TRAINING
Epoch 1/20 - Train Loss: 0.2550, Train Acc: 86.19%, Val Loss: 0.0460, Val Acc: 89.94%
Epoch 2/20 - Train Loss: 0.0547, Train Acc: 93.67%, Val Loss: 0.0248, Val Acc: 98.07%
Epoch 3/20 - Train Loss: 0.0409, Train Acc: 95.02%, Val Loss: 0.0680, Val Acc: 90.10%
Epoch 4/20 - Train Loss: 0.0315, Train Acc: 96.29%, Val Loss: 0.0159, Val Acc: 99.14%
Epoch 5/20 - Train Loss: 0.0250, Train Acc: 97.48%, Val Loss: 0.0127, Val Acc: 99.30%
Epoch 6/20 - Train Loss: 0.0183, Train Acc: 98.42%, Val Loss: 0.0167, Val Acc: 99.14%
Epoch 7/20 - Train Loss: 0.0181, Train Acc: 98.23%, Val Loss: 0.0167, Val Acc: 99.09%
Epoch 8/20 - Train Loss: 0.0124, Train Acc: 98.47%, Val Loss: 0.0126, Val Acc: 99.52%
Epoch 9/20 - Train Loss: 0.0172, Train Acc: 98.34%, Val Loss: 0.0175, Val Acc: 99.