In [8]:
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
import torch.nn.functional as F
from torch import Tensor
from typing import Union, Tuple
import numpy as np
from scipy.stats import entropy
from sklearn.ensemble import RandomForestClassifier
import joblib
import matplotlib.pyplot as plt
from einops import rearrange, repeat
from tqdm import tqdm
import cv2
import os
import timm
from torchvision import datasets, transforms
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score
from einops.layers.torch import Rearrange

In [9]:
#–æ–ø—Ä–µ–¥–µ–ª–µ–Ω–∏–µ DeiT –º–æ–¥–µ–ª–∏
class PatchEmbedding(nn.Module):
    def __init__(self, in_channels: int = 3, patch_size: int = 16, emb_size: int = 384, img_size: int = 224):
        super().__init__()
        self.patch_size = patch_size
        self.projection = nn.Sequential(
            nn.Conv2d(in_channels, emb_size, kernel_size=patch_size, stride=patch_size),
            Rearrange('b e h w -> b (h w) e'),
        )
        self.cls_token = nn.Parameter(torch.randn(1, 1, emb_size))
        self.dist_token = nn.Parameter(torch.randn(1, 1, emb_size))
        num_patches = (img_size // patch_size) ** 2
        self.positions = nn.Parameter(torch.randn(num_patches + 2, emb_size))

    def forward(self, x: Tensor) -> Tensor:
        b, _, _, _ = x.shape
        x = self.projection(x)
        cls_tokens = repeat(self.cls_token, '() n e -> b n e', b=b)
        dist_tokens = repeat(self.dist_token, '() n e -> b n e', b=b)
        x = torch.cat([cls_tokens, dist_tokens, x], dim=1)
        x += self.positions
        return x

In [10]:
class ClassificationHead(nn.Module):
    def __init__(self, emb_size: int = 384, n_classes: int = 2):
        super().__init__()
        self.head = nn.Linear(emb_size, n_classes)
        self.dist_head = nn.Linear(emb_size, n_classes)

    def forward(self, x: Tensor) -> Tensor:
        x_cls, x_dist = x[:, 0], x[:, 1]
        x_head = self.head(x_cls)
        x_dist_head = self.dist_head(x_dist)
        if self.training:
            return x_head, x_dist_head
        else:
            return (x_head + x_dist_head) / 2


class MultiHeadAttention(nn.Module):
    def __init__(self, emb_size: int = 384, num_heads: int = 6, dropout: float = 0.):
        super().__init__()
        self.emb_size = emb_size
        self.num_heads = num_heads
        self.qkv = nn.Linear(emb_size, emb_size * 3)
        self.att_drop = nn.Dropout(dropout)
        self.projection = nn.Linear(emb_size, emb_size)

    def forward(self, x: Tensor, mask: Tensor = None) -> Tensor:
        qkv = rearrange(self.qkv(x), "b n (h d qkv) -> (qkv) b h n d", h=self.num_heads, qkv=3)
        queries, keys, values = qkv[0], qkv[1], qkv[2]
        energy = torch.einsum('bhqd, bhkd -> bhqk', queries, keys)
        if mask is not None:
            fill_value = torch.finfo(torch.float32).min
            energy.mask_fill(~mask, fill_value)
        scaling = self.emb_size ** (1/2)
        att = F.softmax(energy, dim=-1) / scaling
        att = self.att_drop(att)
        out = torch.einsum('bhal, bhlv -> bhav', att, values)
        out = rearrange(out, "b h n d -> b n (h d)")
        out = self.projection(out)
        return out


class ResidualAdd(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        res = x
        x = self.fn(x, **kwargs)
        x += res
        return x


class FeedForwardBlock(nn.Sequential):
    def __init__(self, emb_size: int, expansion: int = 4, drop_p: float = 0.):
        super().__init__(
            nn.Linear(emb_size, expansion * emb_size),
            nn.GELU(),
            nn.Dropout(drop_p),
            nn.Linear(expansion * emb_size, emb_size),
        )


class TransformerEncoderBlock(nn.Sequential):
    def __init__(self, emb_size: int = 384, num_heads: int = 6, drop_p: float = 0., forward_expansion: int = 4,
                 forward_drop_p: float = 0.):
        super().__init__(
            ResidualAdd(nn.Sequential(
                nn.LayerNorm(emb_size),
                MultiHeadAttention(emb_size, num_heads, drop_p),
                nn.Dropout(drop_p)
            )),
            ResidualAdd(nn.Sequential(
                nn.LayerNorm(emb_size),
                FeedForwardBlock(emb_size, expansion=forward_expansion, drop_p=forward_drop_p),
                nn.Dropout(drop_p)
            ))
        )


class TransformerEncoder(nn.Sequential):
    def __init__(self, depth: int = 12, **kwargs):
        super().__init__(*[
            TransformerEncoderBlock(**kwargs)
            for _ in range(depth)
        ])

In [11]:
class DeiT(nn.Sequential):
    def __init__(self, in_channels: int = 3, patch_size: int = 16, emb_size: int = 384,
                 img_size: int = 224, depth: int = 12, n_classes: int = 2):
        super().__init__(
            PatchEmbedding(in_channels, patch_size, emb_size, img_size),
            TransformerEncoder(depth, emb_size=emb_size),
            ClassificationHead(emb_size, n_classes)
        )

In [12]:
#–æ–ø—Ä–µ–¥–µ–ª–µ–Ω–∏–µ Loss –∏ Grad-CAM
class HardDistillationLoss(nn.Module):
    def __init__(self, teacher: nn.Module):
        super().__init__()
        self.teacher = teacher
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, inputs: Tensor, outputs: tuple[Tensor, Tensor], labels: Tensor) -> Tensor:
        outputs_cls, outputs_dist = outputs
        base_loss = self.criterion(outputs_cls, labels)

        with torch.no_grad():
            teacher_outputs = self.teacher(inputs)
        teacher_logits = teacher_outputs[:, :2]
        teacher_labels = torch.argmax(teacher_logits, dim=1)

        teacher_loss = self.criterion(outputs_dist, teacher_labels)

        return 0.5 * base_loss + 0.5 * teacher_loss

In [13]:
class GradCAM:
    def __init__(self, model, target_layer):
        self.model = model
        self.target_layer = target_layer
        self.gradients = None
        self.activations = None
        target_layer.register_forward_hook(self.save_activations)
        target_layer.register_backward_hook(self.save_gradients)

    def save_activations(self, module, input, output):
        self.activations = output.detach()

    def save_gradients(self, module, grad_input, grad_output):
        self.gradients = grad_output[0].detach()

    def forward(self, x, class_idx=None):
        original_size = x.shape[-2:]
        h, w = original_size
        logits = self.model(x)
        if isinstance(logits, tuple):
            logits = logits[0]
        self.model.zero_grad()
        if class_idx is None:
            class_idx = logits.argmax(dim=1).item()
        one_hot = torch.zeros_like(logits)
        one_hot[0][class_idx] = 1
        one_hot.requires_grad_(True)
        output = (one_hot * logits).sum()
        output.backward(retain_graph=True)
        gradients = self.gradients.cpu().numpy()[0]
        activations = self.activations.cpu().numpy()[0]
        weights = np.mean(gradients, axis=(1, 2))
        cam = np.zeros(activations.shape[1:], dtype=np.float32)
        for i, w in enumerate(weights):
            cam += w * activations[i]
        cam = np.maximum(cam, 0)
        cam = cv2.resize(cam, (int(w), int(h)))
        cam = cam - np.min(cam)
        cam = cam / np.max(cam)
        return cam

    def __call__(self, x, class_idx=None):
        return self.forward(x, class_idx)

In [None]:
# –ø–æ–¥–≥–æ—Ç–æ–≤–∫–∞ –¥–∞–Ω–Ω—ã—Ö –∏ –º–æ–¥–µ–ª–µ–π
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

ds = datasets.ImageFolder(root='Testing', transform=transform)
dl = DataLoader(ds, batch_size=32, shuffle=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

teacher = timm.create_model('vit_large_patch16_224', pretrained=True, num_classes=2)
student = DeiT(n_classes=2)

optimizer = optim.Adam(student.parameters(), lr=0.001)
criterion = HardDistillationLoss(teacher)

teacher.to(device)
student.to(device)

# ---------------------------
# üèãÔ∏è‚Äç‚ôÇÔ∏è 4. –û–±—É—á–µ–Ω–∏–µ
# ---------------------------

train_losses = []
train_accuracies = []
train_f1_scores = []
train_auc_scores = []

try:
    for epoch in range(5):
        student.train()
        running_loss = 0.0
        all_labels, all_preds, all_probs = [], [], []

        for batch in tqdm(dl, desc=f"Epoch {epoch+1}/5"):
            inputs, labels = batch
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = student(inputs)
            if isinstance(outputs, tuple):
                outputs_cls, outputs_dist = outputs
            else:
                outputs_cls = outputs
                outputs_dist = outputs
            loss = criterion(inputs, (outputs_cls, outputs_dist), labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

            probs = torch.softmax(outputs_cls, dim=1).detach().cpu().numpy()
            preds = probs.argmax(axis=1)
            all_labels.extend(labels.cpu().numpy())
            all_preds.extend(preds)
            all_probs.extend(probs[:, 1])

        epoch_loss = running_loss / len(dl)
        epoch_acc = accuracy_score(all_labels, all_preds)
        epoch_f1 = f1_score(all_labels, all_preds)
        epoch_auc = roc_auc_score(all_labels, all_probs)

        train_losses.append(epoch_loss)
        train_accuracies.append(epoch_acc)
        train_f1_scores.append(epoch_f1)
        train_auc_scores.append(epoch_auc)

        print(f"Epoch {epoch+1}/5 | Loss: {epoch_loss:.4f}, Acc: {epoch_acc:.4f}, F1: {epoch_f1:.4f}, AUC: {epoch_auc:.4f}")

except Exception as e:
    print("–û—à–∏–±–∫–∞:", e)

Epoch 1/5: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 193/193 [44:37<00:00, 13.87s/it]


Epoch 1/5 | Loss: 0.8754, Acc: 0.5351, F1: 0.4813, AUC: 0.5473


Epoch 2/5: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 193/193 [41:38<00:00, 12.95s/it]


Epoch 2/5 | Loss: 0.7089, Acc: 0.5668, F1: 0.5091, AUC: 0.5808


Epoch 3/5: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 193/193 [47:31<00:00, 14.77s/it]


Epoch 3/5 | Loss: 0.7032, Acc: 0.5723, F1: 0.5180, AUC: 0.5953


Epoch 4/5:  26%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà                       | 50/193 [14:02<39:42, 16.66s/it]

In [None]:
# ---------------------------
# ü§ñ 5. –ú–µ—Ç–∞–º–æ–¥–µ–ª—å —É–≤–µ—Ä–µ–Ω–Ω–æ—Å—Ç–∏
# ---------------------------

X_meta = []
y_meta = []

student.eval()
with torch.no_grad():
    for images, labels in dl:
        images = images.to(device)
        outputs = student(images)
        probs = torch.softmax(outputs, dim=1).cpu().numpy()
        preds = np.argmax(probs, axis=1)
        true_labels = labels.numpy()

        for i in range(len(true_labels)):
            p0, p1 = probs[i]
            max_p = max(p0, p1)
            entr = entropy([p0, p1])
            X_meta.append([p0, p1, max_p, entr])
            y_meta.append(1 if preds[i] == true_labels[i] else 0)

meta_model = RandomForestClassifier(n_estimators=100, random_state=42)
meta_model.fit(X_meta, y_meta)
joblib.dump(meta_model, "meta_model.pkl")
print("–ú–µ—Ç–∞–º–æ–¥–µ–ª—å —Å–æ—Ö—Ä–∞–Ω–µ–Ω–∞ –∫–∞–∫ meta_model.pkl")


def classify(image_tensor):
    student.eval()
    with torch.no_grad():
        output = student(image_tensor.unsqueeze(0).to(device))
        probs = torch.softmax(output, dim=1).cpu().numpy()[0]

    p0, p1 = probs
    entr = entropy([p0, p1])
    max_p = max(p0, p1)
    features = [[p0, p1, max_p, entr]]
    trust = meta_model.predict(features)[0]

    if trust == 1:
        return "–∑–¥–æ—Ä–æ–≤" if p0 > p1 else "–±–æ–ª–µ–Ω"
    else:
        return "–∞–Ω–æ–º–∞–ª—å–Ω–æ–µ"

In [None]:
image, label = ds[0]
result = classify(image)
print("–†–µ–∑—É–ª—å—Ç–∞—Ç –∫–ª–∞—Å—Å–∏—Ñ–∏–∫–∞—Ü–∏–∏:", result)