In [4]:
import torch
from torch import nn
import torch.nn.functional as F
from torch import Tensor

class HardDistillationLoss(nn.Module):
    def __init__(self, teacher: nn.Module):
        super().__init__()
        self.teacher = teacher
        self.criterion = nn.CrossEntropyLoss()  # Работает с любым количеством классов
        
    def forward(self, inputs: Tensor, outputs: tuple[Tensor, Tensor], labels: Tensor) -> Tensor:
        outputs_cls, outputs_dist = outputs
        
        # Базовая потеря (CLS)
        base_loss = self.criterion(outputs_cls, labels)

        # Вычисляем предсказания учителя
        with torch.no_grad():
            teacher_outputs = self.teacher(inputs)
        teacher_labels = torch.argmax(teacher_outputs, dim=1)  # Теперь метки 0, 1, 2

        # Потеря для DIST
        teacher_loss = self.criterion(outputs_dist, teacher_labels)

        return 0.5 * base_loss + 0.5 * teacher_loss

In [6]:
from typing import Union

In [8]:
from einops import rearrange, reduce, repeat
from einops.layers.torch import Rearrange, Reduce

class PatchEmbedding(nn.Module):
    def __init__(self, in_channels: int = 3, patch_size: int = 16, emb_size: int = 768, img_size: int = 224):
        super().__init__()
        self.patch_size = patch_size

        # Проекция патчей
        self.projection = nn.Sequential(
            nn.Conv2d(in_channels, emb_size, kernel_size=patch_size, stride=patch_size),
            Rearrange('b e (h) (w) -> b (h w) e'),
        )

        # Токены CLS и DIST
        self.cls_token = nn.Parameter(torch.randn(1, 1, emb_size))
        self.dist_token = nn.Parameter(torch.randn(1, 1, emb_size))  # Убедитесь, что это определено

        # Позиционные эмбеддинги
        num_patches = (img_size // patch_size) ** 2
        self.positions = nn.Parameter(torch.randn(num_patches + 2, emb_size))  # +2 для cls_token и dist_token

    def forward(self, x: Tensor) -> Tensor:
        b, _, _, _ = x.shape

        # Проекция патчей
        x = self.projection(x)

        # Создание токенов CLS и DIST
        cls_tokens = repeat(self.cls_token, '() n e -> b n e', b=b)
        dist_tokens = repeat(self.dist_token, '() n e -> b n e', b=b)

        # Добавление токенов CLS и DIST к входным данным
        x = torch.cat([cls_tokens, dist_tokens, x], dim=1)

        # Добавление позиционных эмбеддингов
        x += self.positions

        return x

In [9]:
class ClassificationHead(nn.Module):
    def __init__(self, emb_size: int = 768, n_classes: int = 2):       
        super().__init__()

        self.head = nn.Linear(emb_size, n_classes)
        self.dist_head = nn.Linear(emb_size, n_classes)

    def forward(self, x: Tensor) -> Tensor:
        x, x_dist = x[:, 0], x[:, 1]
        x_head = self.head(x)
        x_dist_head = self.dist_head(x_dist)
        
        if self.training:
            x = x_head, x_dist_head  # Возвращает кортеж
        else:
            x = (x_head + x_dist_head) / 2  # Возвращает тензор
        return x

In [12]:
class MultiHeadAttention(nn.Module):
    def __init__(self, emb_size: int = 768, num_heads: int = 8, dropout: float = 0):
        super().__init__()
        self.emb_size = emb_size
        self.num_heads = num_heads
        # fuse the queries, keys and values in one matrix
        self.qkv = nn.Linear(emb_size, emb_size * 3)
        self.att_drop = nn.Dropout(dropout)
        self.projection = nn.Linear(emb_size, emb_size)
        
    def forward(self, x : Tensor, mask: Tensor = None) -> Tensor:
        # split keys, queries and values in num_heads
        qkv = rearrange(self.qkv(x), "b n (h d qkv) -> (qkv) b h n d", h=self.num_heads, qkv=3)
        queries, keys, values = qkv[0], qkv[1], qkv[2]
        # sum up over the last axis
        energy = torch.einsum('bhqd, bhkd -> bhqk', queries, keys) # batch, num_heads, query_len, key_len
        if mask is not None:
            fill_value = torch.finfo(torch.float32).min
            energy.mask_fill(~mask, fill_value)
            
        scaling = self.emb_size ** (1/2)
        att = F.softmax(energy, dim=-1) / scaling
        att = self.att_drop(att)
        # sum up over the third axis
        out = torch.einsum('bhal, bhlv -> bhav ', att, values)
        out = rearrange(out, "b h n d -> b n (h d)")
        out = self.projection(out)
        return out
    
class ResidualAdd(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn
        
    def forward(self, x, **kwargs):
        res = x
        x = self.fn(x, **kwargs)
        x += res
        return x
    
class FeedForwardBlock(nn.Sequential):
    def __init__(self, emb_size: int, expansion: int = 4, drop_p: float = 0.):
        super().__init__(
            nn.Linear(emb_size, expansion * emb_size),
            nn.GELU(),
            nn.Dropout(drop_p),
            nn.Linear(expansion * emb_size, emb_size),
        )
        
class TransformerEncoderBlock(nn.Sequential):
    def __init__(self,
                 emb_size: int = 768,
                 drop_p: float = 0.,
                 forward_expansion: int = 4,
                 forward_drop_p: float = 0.,
                 ** kwargs):
        super().__init__(
            ResidualAdd(nn.Sequential(
                nn.LayerNorm(emb_size),
                MultiHeadAttention(emb_size, **kwargs),
                nn.Dropout(drop_p)
            )),
            ResidualAdd(nn.Sequential(
                nn.LayerNorm(emb_size),
                FeedForwardBlock(
                    emb_size, expansion=forward_expansion, drop_p=forward_drop_p),
                nn.Dropout(drop_p)
            )
            ))

class TransformerEncoder(nn.Sequential):
    def __init__(self, depth: int = 12, **kwargs):
        super().__init__(*[TransformerEncoderBlock(**kwargs) for _ in range(depth)])

In [14]:
class DeiT(nn.Sequential):
    def __init__(self,     
                in_channels: int = 3,
                patch_size: int = 16,
                emb_size: int = 768,
                img_size: int = 224,
                depth: int = 12,
                n_classes: int = 1000,
                **kwargs):
        super().__init__(
            PatchEmbedding(in_channels, patch_size, emb_size, img_size),
            TransformerEncoder(depth, emb_size=emb_size, **kwargs),
            ClassificationHead(emb_size, n_classes))

In [20]:
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Определение преобразований для изображений
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # Изменяем размер до 224x224
    transforms.ToTensor(),          # Преобразуем в тензор
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # Нормализация
])

# Создание датасета с помощью ImageFolder
ds = datasets.ImageFolder(root='TestingAndTrainingFinal_min', transform=transform)

# Создание DataLoader
dl = DataLoader(ds, batch_size=32, shuffle=False)

print(ds.classes)  # ['Anomaly', 'Brain_tumor', 'Health', ]
print(len(ds))

['Anomaly', 'Brain_tumor', 'Health']
822


In [22]:
from torch.optim import Adam #
import timm
from tqdm import tqdm

# Teacher model (Vision Transformer)
teacher = timm.create_model('vit_large_patch16_224', pretrained=True, num_classes=3)
teacher.eval()

# Student model (DeiT)
student = timm.create_model('deit_small_patch16_224', pretrained=True, num_classes=3)

# teacher = ViT.vit_large_patch16_224()
# student = DeiT.deit_small_patch16_224()

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Создание датасета

# ds = datasets.ImageFolder(
#     root='archive-4',
#     transform=transform,
#     target_transform=lambda x: 0 if ds.classes[x] == 'tumor' else 1
# )
ds = datasets.ImageFolder(root='TestingAndTrainingFinal_min', transform=transform)

# Создание DataLoader
dl = DataLoader(ds, batch_size=32, shuffle=True)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
teacher.to(device)

dummy_input = torch.randn(1, 3, 224, 224).to(device)
output = teacher(dummy_input)
print(output.shape)  # Должно быть torch.Size([1, 2])
print("Выход учителя:", teacher(dummy_input).shape)  # torch.Size([1, 3])

# Для студента
print("Выход студента:", student(dummy_input))

student = DeiT(
    in_channels=3,
    patch_size=16,
    emb_size=384,
    img_size=224,
    depth=12,
    n_classes=3  # Два класса
)

# Оптимизатор
optimizer = Adam(student.parameters(), lr=0.001)

# Функция потерь
criterion = HardDistillationLoss(teacher)

# Цикл обучения

teacher.to(device)
student.to(device)

try:
    for epoch in range(10):
        student.train()
        running_loss = 0.0
        for inputs, labels in tqdm(dl, desc=f"Epoch {epoch+1}"):
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = student(inputs)  # Должен вернуть (outputs_cls, outputs_dist)
            loss = criterion(inputs, outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        print(f"Epoch {epoch+1}, Loss: {running_loss / len(dl):.4f}")
except Exception as e:
    print("Ошибка:", e)

1
torch.Size([1, 3])
Выход учителя: torch.Size([1, 3])
Выход студента: tensor([[-0.1888,  0.2712,  0.4434]], grad_fn=<AddmmBackward0>)


Epoch 1: 100%|██████████████████████████████████| 26/26 [05:33<00:00, 12.83s/it]


Epoch 1, Loss: 2.4712


Epoch 2: 100%|██████████████████████████████████| 26/26 [07:13<00:00, 16.66s/it]


Epoch 2, Loss: 0.8160


Epoch 3: 100%|██████████████████████████████████| 26/26 [06:51<00:00, 15.83s/it]


Epoch 3, Loss: 0.7778


Epoch 4: 100%|██████████████████████████████████| 26/26 [05:54<00:00, 13.64s/it]


Epoch 4, Loss: 0.6969


Epoch 5: 100%|██████████████████████████████████| 26/26 [06:29<00:00, 14.98s/it]


Epoch 5, Loss: 0.7143


Epoch 6: 100%|██████████████████████████████████| 26/26 [05:49<00:00, 13.43s/it]


Epoch 6, Loss: 0.6611


Epoch 7: 100%|██████████████████████████████████| 26/26 [05:34<00:00, 12.85s/it]


Epoch 7, Loss: 0.6236


Epoch 8: 100%|██████████████████████████████████| 26/26 [05:21<00:00, 12.37s/it]


Epoch 8, Loss: 0.5965


Epoch 9: 100%|██████████████████████████████████| 26/26 [05:17<00:00, 12.21s/it]


Epoch 9, Loss: 0.7062


Epoch 10: 100%|█████████████████████████████████| 26/26 [05:10<00:00, 11.96s/it]

Epoch 10, Loss: 0.6711
3





In [33]:
dummy_input = torch.randn(2, 3, 224, 224)  # Batch size = 2
output = student(dummy_input)
print(output)

(tensor([[-0.5891, -0.1246],
        [-0.5992, -0.1299]], grad_fn=<AddmmBackward0>), tensor([[ 3.9488e-07,  6.1619e-01],
        [-2.4278e-03,  6.2632e-01]], grad_fn=<AddmmBackward0>))


In [82]:
unique_labels = set()
for _, labels in dl:
    unique_labels.update(labels.tolist())
print("Unique labels:", unique_labels)



Unique labels: {1}


In [56]:
for _, labels in dl:
    print("Метки в батче:", labels.unique())
    break

Метки в батче: tensor([0, 1, 2])
