In [1]:
import os
import time

import kagglehub
import numpy as np
from tqdm import tqdm

import timm
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform

import torch
from torch import nn, optim
from torch.amp import autocast
from torchvision import datasets, transforms
from torch.utils.data import random_split, DataLoader

from sklearn.metrics import precision_score, recall_score, f1_score

In [2]:
path = kagglehub.dataset_download("ifigotin/imagenetmini-1000")
path += "/imagenet-mini"

print("Path to dataset files:", path)

Path to dataset files: /kaggle/input/imagenetmini-1000/imagenet-mini


In [3]:
teacher_model = timm.create_model('efficientvit_b3.r256_in1k', pretrained=True)
student_model = timm.create_model('efficientvit_b1.r256_in1k', pretrained=False)

model.safetensors:   0%|          | 0.00/195M [00:00<?, ?B/s]

In [4]:
def load_imagenet_mini(dataset_path, model, split, transforms_list):
    config = resolve_data_config({}, model=model)
    transform = create_transform(**config)

    print(transform)

    if len(transforms_list):
        transform = transforms.Compose([
            *transforms_list,
            transform,
            # transforms.RandomErasing(p=0.5, scale=(0.02, 0.33), ratio=(0.3, 3.33))
        ])

    dataset = datasets.ImageFolder(
        root=os.path.join(dataset_path, split),
        transform=transform
    )
    
    return dataset

In [5]:
transforms_list = [
    transforms.RandomResizedCrop(256, scale=(0.8, 1.0)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.RandomGrayscale(p=0.1),
    transforms.RandomApply([transforms.GaussianBlur(kernel_size=3)], p=0.2),
    transforms.RandomRotation(degrees=15),
]

original_dataset = load_imagenet_mini(path, teacher_model, 'train', transforms_list)

train_dataset, val_dataset = random_split(original_dataset, [0.8, 0.2])
test_dataset = load_imagenet_mini(path, teacher_model, 'val', [])

Compose(
    Resize(size=256, interpolation=bicubic, max_size=None, antialias=True)
    CenterCrop(size=(256, 256))
    MaybeToTensor()
    Normalize(mean=tensor([0.4850, 0.4560, 0.4060]), std=tensor([0.2290, 0.2240, 0.2250]))
)
Compose(
    Resize(size=256, interpolation=bicubic, max_size=None, antialias=True)
    CenterCrop(size=(256, 256))
    MaybeToTensor()
    Normalize(mean=tensor([0.4850, 0.4560, 0.4060]), std=tensor([0.2290, 0.2240, 0.2250]))
)


In [6]:
batch_size = 64

train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,
    pin_memory=True
)

val_loader = DataLoader(
    val_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=4,
    pin_memory=True
)

test_loader = DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=4,
    pin_memory=True
)

In [7]:
def train_one_epoch(
    student_model: nn.Module,
    teacher_model: nn.Module,
    train_loader: DataLoader,
    optimizer: optim.Optimizer,
    criterion: nn.Module,
    distill_criterion: nn.Module,
    alpha: float,
    device: str,
):
    student_model.train()
    teacher_model.eval()

    loss_accum = 0

    for img_batch, targets in tqdm(train_loader, desc='Training model'):
        optimizer.zero_grad()

        img_batch = img_batch.to(device)
        targets = targets.to(device)

        student_output = student_model(img_batch)

        with torch.no_grad():
            teacher_output = teacher_model(img_batch)

        loss = alpha * criterion(student_output, targets) + (1 - alpha) * distill_criterion(student_output, teacher_output)
        loss.backward()

        optimizer.step()

        loss_accum += loss.item()

    return loss_accum / len(train_loader)


def validate_model(
    student_model: nn.Module,
    teacher_model: nn.Module,
    val_loader: DataLoader,
    criterion: nn.Module,
    distill_criterion: nn.Module,
    alpha: float,
    device: str,
):
    predictions = []
    ground_truth = []
    loss_accum = 0

    student_model.eval()
    teacher_model.eval()

    for img_batch, targets in tqdm(val_loader, desc='Validating model'):
        img_batch = img_batch.to(device)
        targets = targets.to(device)
        
        with torch.no_grad():
            student_output = student_model(img_batch)
            teacher_output = teacher_model(img_batch)
            
            loss = alpha * criterion(student_output, targets) + (1 - alpha) * distill_criterion(student_output, teacher_output)

        loss_accum += loss.item()
        
        preds = student_output.argmax(dim=1).cpu().numpy()
        predictions.extend(preds)
        ground_truth.extend(targets.cpu().numpy().tolist())

    return loss_accum / len(val_loader), f1_score(ground_truth, predictions, average='macro')

In [8]:
def distillation_loss(student_logits, teacher_logits, T):
    soft_targets = nn.functional.softmax(teacher_logits / T, dim=-1)
    soft_prob = nn.functional.log_softmax(student_logits / T, dim=-1)

    return torch.sum(soft_targets * (soft_targets.log() - soft_prob)) / soft_prob.size()[0] * (T**2)

In [9]:
epoch_num = 25
learning_rate = 1e-3
alpha = 0.5
distill_temp = 2
device = 'cuda' if torch.cuda.is_available() else 'cpu'

optimizer = optim.AdamW(student_model.parameters(), lr=learning_rate, weight_decay=0.05)
class_loss = nn.CrossEntropyLoss()
distill_loss = lambda x, y: distillation_loss(x, y, distill_temp)
best_score = 0

student_model.to(device)
teacher_model.to(device)

for epoch_ind in range(1, epoch_num + 1):
    train_loss = train_one_epoch(student_model, teacher_model, train_loader, optimizer, class_loss, distill_loss, alpha, device)
    val_loss, f1_score_val = validate_model(student_model, teacher_model, val_loader, class_loss, distill_loss, alpha, device)

    print(f'Epoch #{epoch_ind} Train Loss: {round(train_loss, 5)} Val Loss: {round(val_loss, 5)} F1: {round(f1_score_val, 5)}')

    if f1_score_val > best_score:
        best_score = f1_score_val
        torch.save(student_model.state_dict(), 'best_model.pt')

Training model: 100%|██████████| 435/435 [03:48<00:00,  1.91it/s]
Validating model: 100%|██████████| 109/109 [00:52<00:00,  2.06it/s]


Epoch #1 Train Loss: 5.39366 Val Loss: 5.31047 F1: 3e-05


Training model: 100%|██████████| 435/435 [03:39<00:00,  1.98it/s]
Validating model: 100%|██████████| 109/109 [00:46<00:00,  2.34it/s]


Epoch #2 Train Loss: 5.29791 Val Loss: 5.26198 F1: 6e-05


Training model: 100%|██████████| 435/435 [03:39<00:00,  1.98it/s]
Validating model: 100%|██████████| 109/109 [00:46<00:00,  2.32it/s]


Epoch #3 Train Loss: 5.22014 Val Loss: 5.1828 F1: 0.00035


Training model: 100%|██████████| 435/435 [03:40<00:00,  1.97it/s]
Validating model: 100%|██████████| 109/109 [00:47<00:00,  2.29it/s]


Epoch #4 Train Loss: 5.12677 Val Loss: 5.07856 F1: 0.00084


Training model: 100%|██████████| 435/435 [03:40<00:00,  1.98it/s]
Validating model: 100%|██████████| 109/109 [00:47<00:00,  2.30it/s]


Epoch #5 Train Loss: 5.04272 Val Loss: 5.05012 F1: 0.00124


Training model: 100%|██████████| 435/435 [03:39<00:00,  1.98it/s]
Validating model: 100%|██████████| 109/109 [00:46<00:00,  2.33it/s]


Epoch #6 Train Loss: 4.98363 Val Loss: 4.94973 F1: 0.00268


Training model: 100%|██████████| 435/435 [03:40<00:00,  1.98it/s]
Validating model: 100%|██████████| 109/109 [00:47<00:00,  2.32it/s]


Epoch #7 Train Loss: 4.92037 Val Loss: 4.92982 F1: 0.0027


Training model: 100%|██████████| 435/435 [03:40<00:00,  1.98it/s]
Validating model: 100%|██████████| 109/109 [00:47<00:00,  2.31it/s]


Epoch #8 Train Loss: 4.84192 Val Loss: 4.85249 F1: 0.00543


Training model: 100%|██████████| 435/435 [03:40<00:00,  1.97it/s]
Validating model: 100%|██████████| 109/109 [00:49<00:00,  2.22it/s]


Epoch #9 Train Loss: 4.75874 Val Loss: 4.80423 F1: 0.0067


Training model: 100%|██████████| 435/435 [03:40<00:00,  1.98it/s]
Validating model: 100%|██████████| 109/109 [00:46<00:00,  2.32it/s]


Epoch #10 Train Loss: 4.70039 Val Loss: 4.7483 F1: 0.00806


Training model: 100%|██████████| 435/435 [03:39<00:00,  1.98it/s]
Validating model: 100%|██████████| 109/109 [00:47<00:00,  2.32it/s]


Epoch #11 Train Loss: 4.60952 Val Loss: 4.70739 F1: 0.01115


Training model: 100%|██████████| 435/435 [03:40<00:00,  1.98it/s]
Validating model: 100%|██████████| 109/109 [00:48<00:00,  2.25it/s]


Epoch #12 Train Loss: 4.53295 Val Loss: 4.62155 F1: 0.01513


Training model: 100%|██████████| 435/435 [03:40<00:00,  1.97it/s]
Validating model: 100%|██████████| 109/109 [00:47<00:00,  2.28it/s]


Epoch #13 Train Loss: 4.45792 Val Loss: 4.56707 F1: 0.01826


Training model: 100%|██████████| 435/435 [03:40<00:00,  1.98it/s]
Validating model: 100%|██████████| 109/109 [00:47<00:00,  2.29it/s]


Epoch #14 Train Loss: 4.38696 Val Loss: 4.50804 F1: 0.02124


Training model: 100%|██████████| 435/435 [03:39<00:00,  1.98it/s]
Validating model: 100%|██████████| 109/109 [00:47<00:00,  2.30it/s]


Epoch #15 Train Loss: 4.3688 Val Loss: 4.50769 F1: 0.02892


Training model: 100%|██████████| 435/435 [03:40<00:00,  1.98it/s]
Validating model: 100%|██████████| 109/109 [00:47<00:00,  2.31it/s]


Epoch #16 Train Loss: 4.28188 Val Loss: 4.46019 F1: 0.02773


Training model: 100%|██████████| 435/435 [03:40<00:00,  1.98it/s]
Validating model: 100%|██████████| 109/109 [00:48<00:00,  2.26it/s]


Epoch #17 Train Loss: 4.21237 Val Loss: 4.44189 F1: 0.03377


Training model: 100%|██████████| 435/435 [03:39<00:00,  1.98it/s]
Validating model: 100%|██████████| 109/109 [00:47<00:00,  2.27it/s]


Epoch #18 Train Loss: 4.15069 Val Loss: 4.41227 F1: 0.03134


Training model: 100%|██████████| 435/435 [03:40<00:00,  1.98it/s]
Validating model: 100%|██████████| 109/109 [00:47<00:00,  2.29it/s]


Epoch #19 Train Loss: 4.12558 Val Loss: 5.08351 F1: 0.00728


Training model: 100%|██████████| 435/435 [03:40<00:00,  1.97it/s]
Validating model: 100%|██████████| 109/109 [00:48<00:00,  2.25it/s]


Epoch #20 Train Loss: 4.3122 Val Loss: 4.41818 F1: 0.03423


Training model: 100%|██████████| 435/435 [03:40<00:00,  1.97it/s]
Validating model: 100%|██████████| 109/109 [00:49<00:00,  2.21it/s]


Epoch #21 Train Loss: 4.13712 Val Loss: 4.42311 F1: 0.04167


Training model: 100%|██████████| 435/435 [03:40<00:00,  1.97it/s]
Validating model: 100%|██████████| 109/109 [00:49<00:00,  2.21it/s]


Epoch #22 Train Loss: 4.01367 Val Loss: 4.35647 F1: 0.04664


Training model: 100%|██████████| 435/435 [03:40<00:00,  1.98it/s]
Validating model: 100%|██████████| 109/109 [00:47<00:00,  2.29it/s]


Epoch #23 Train Loss: 3.93953 Val Loss: 4.32902 F1: 0.04634


Training model: 100%|██████████| 435/435 [03:40<00:00,  1.98it/s]
Validating model: 100%|██████████| 109/109 [00:47<00:00,  2.29it/s]


Epoch #24 Train Loss: 3.87795 Val Loss: 4.28042 F1: 0.05551


Training model: 100%|██████████| 435/435 [03:40<00:00,  1.97it/s]
Validating model: 100%|██████████| 109/109 [00:48<00:00,  2.26it/s]

Epoch #25 Train Loss: 3.81884 Val Loss: 4.26605 F1: 0.05509





## Compare models

In [10]:
!pip install gputil

Collecting gputil
  Downloading GPUtil-1.4.0.tar.gz (5.5 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: gputil
  Building wheel for gputil (setup.py) ... [?25l[?25hdone
  Created wheel for gputil: filename=GPUtil-1.4.0-py3-none-any.whl size=7392 sha256=721c903f411efd0394e208374c89e7c9dc5f0c94e20e8f6c14891dc4d99995e5
  Stored in directory: /root/.cache/pip/wheels/2b/4d/8f/55fb4f7b9b591891e8d3f72977c4ec6c7763b39c19f0861595
Successfully built gputil
Installing collected packages: gputil
Successfully installed gputil-1.4.0


In [11]:
import torch
import time
import timm
import psutil
import os
import GPUtil
import numpy as np
from torch.amp import autocast
from torchvision import datasets, transforms
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform
from sklearn.metrics import precision_score, recall_score, f1_score

def print_memory_usage(label=""):
    """Выводит использование памяти в мегабайтах"""
    if label:
        print(f"\n--- Memory Usage ({label}) ---")
    else:
        print("\n--- Memory Usage ---")
        
    # CPU RAM в MB
    process = psutil.Process(os.getpid())
    ram_used = process.memory_info().rss / (1024 ** 2)
    print(f"CPU RAM used: {ram_used:.2f} MB")
    
    # GPU VRAM в MB
    gpus = GPUtil.getGPUs()
    for gpu in gpus:
        vram_used = gpu.memoryUsed
        vram_total = gpu.memoryTotal
        print(f"GPU {gpu.id} VRAM: {vram_used:.2f} MB / {vram_total:.2f} MB")

def get_model_size(model):
    param_size = 0
    for param in model.parameters():
        param_size += param.nelement() * param.element_size()
    buffer_size = 0
    for buffer in model.buffers():
        buffer_size += buffer.nelement() * buffer.element_size()
    return (param_size + buffer_size) / 1024**2

def calculate_metrics(model, device, data_loader):
    model.eval()
    all_preds = []
    all_targets = []

    model.to(device)
    with torch.no_grad():
        for images, targets in data_loader:
            images = images.to(device)
            outputs = model(images)
            preds = outputs.argmax(dim=1).cpu().numpy()
            
            all_preds.extend(preds)
            all_targets.extend(targets.numpy())
    
    precision = precision_score(all_targets, all_preds, average='macro')
    recall = recall_score(all_targets, all_preds, average='macro')
    f1 = f1_score(all_targets, all_preds, average='macro')
    
    return precision, recall, f1

def load_imagenet_mini(dataset_path, model):
    # Создаем трансформы на основе модели
    config = resolve_data_config({}, model=model)
    transform = create_transform(**config)
    
    # Загружаем датасет
    dataset = datasets.ImageFolder(
        root=os.path.join(dataset_path, 'val'),
        transform=transform
    )
    
    data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=64,
        shuffle=False,
        num_workers=4,
        pin_memory=True
    )
    
    return data_loader

def benchmark_model(model, device, input_tensor, num_runs=10, warmup=3, use_amp=False):
    model = model.to(device)
    input_tensor = input_tensor.to(device)
    
    # Warmup
    print(f"\n🔥 Warming up ({warmup} runs) on {device}...")
    for _ in range(warmup):
        with torch.no_grad():
            if use_amp and device.type == 'cuda':
                with autocast(device_type='cuda', dtype=torch.float16):
                    _ = model(input_tensor)
            else:
                _ = model(input_tensor)
    
    # Benchmark
    print(f"🚀 Benchmarking ({num_runs} runs) on {device}...")
    start_time = time.time()
    
    for _ in range(num_runs):
        with torch.no_grad():
            if use_amp and device.type == 'cuda':
                with autocast(device_type='cuda', dtype=torch.float16):
                    _ = model(input_tensor)
            else:
                _ = model(input_tensor)
    
    total_time = (time.time() - start_time) * 1000
    avg_time = total_time / num_runs
    print(f"✅ Average inference: {avg_time:.2f} ms")
    print(f"📊 Total time: {total_time:.2f} ms | FPS: {1000/(avg_time + 1e-9):.1f}")
    
    return avg_time

def main(model, dataset_path=None):
    device_cpu = torch.device('cpu')
    device_gpu = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    print("\n🔍 Initial memory state:")
    print_memory_usage("Before loading model")
    model.eval()
    print(f"📏 Model size: {get_model_size(model):.2f} MB")

    print("\n🔍 Initial memory state:")
    print_memory_usage("Model loaded")
    
    # Бенчмарки
    input_tensor = torch.randn(1, 3, 224, 224)
    
    print("\n🧪 Benchmarking on CPU:")
    cpu_time = benchmark_model(model, device_cpu, input_tensor)
    print_memory_usage("After CPU benchmark")
    
    if torch.cuda.is_available():
        print("\n🎮 Benchmarking on GPU:")
        gpu_time = benchmark_model(model, device_gpu, input_tensor)
        print_memory_usage("After GPU test")
        
        print("\n⚡ Benchmarking with AMP:")
        gpu_amp_time = benchmark_model(model, device_gpu, input_tensor, use_amp=True)
        print_memory_usage("After AMP test")
        
        print("\n📈 Results Summary:")
        print(f"| Device | Inference Time (ms) | Speedup vs CPU |")
        print("|--------|---------------------|----------------|")
        print(f"| CPU    | {cpu_time:19.2f} | {'—':^15} |")
        print(f"| GPU    | {gpu_time:19.2f} | {cpu_time/gpu_time:^15.1f}x |")
        print(f"| AMP    | {gpu_amp_time:19.2f} | {cpu_time/gpu_amp_time:^15.1f}x |")
    else:
        print("\n❌ CUDA not available")
        print(f"⏱️ CPU inference time: {cpu_time:.2f} ms")
        if dataset_path:
            print("\n🎯 Quality Metrics (CPU):")
            print(f"Precision: {precision_cpu:.4f}")
            print(f"Recall:    {recall_cpu:.a4f}")
            print(f"F1-Score:  {f1_cpu:.4f}")

    # Загрузка и расчет метрик качества
    if dataset_path:
        print("\n📊 Loading ImageNetMini dataset...")
        data_loader = load_imagenet_mini(dataset_path, model)
        
        # print("\n🧮 Calculating metrics on CPU:")
        # precision_cpu, recall_cpu, f1_cpu = calculate_metrics(model, device_cpu, data_loader)
        
        if torch.cuda.is_available():
            print("\n🧮 Calculating metrics on GPU:")
            precision_gpu, recall_gpu, f1_gpu = calculate_metrics(model, device_gpu, data_loader)

            print("\n🎯 Quality Metrics Summary:")
            print("| Device | Precision | Recall  | F1-Score |")
            print("|--------|-----------|---------|----------|")
            # print(f"| CPU    | {precision_cpu:.4f}  | {recall_cpu:.4f} | {f1_cpu:.4f}  |")
            print(f"| GPU    | {precision_gpu:.4f}  | {recall_gpu:.4f} | {f1_gpu:.4f}  |")

In [12]:
base_model = timm.create_model('efficientvit_b1.r256_in1k', pretrained=True)

model.safetensors:   0%|          | 0.00/36.5M [00:00<?, ?B/s]

In [13]:
main(base_model, path)


🔍 Initial memory state:

--- Memory Usage (Before loading model) ---
CPU RAM used: 1925.69 MB
GPU 0 VRAM: 7747.00 MB / 16384.00 MB
📏 Model size: 34.77 MB

🔍 Initial memory state:

--- Memory Usage (Model loaded) ---
CPU RAM used: 1925.69 MB
GPU 0 VRAM: 7747.00 MB / 16384.00 MB

🧪 Benchmarking on CPU:

🔥 Warming up (3 runs) on cpu...
🚀 Benchmarking (10 runs) on cpu...
✅ Average inference: 32.09 ms
📊 Total time: 320.94 ms | FPS: 31.2

--- Memory Usage (After CPU benchmark) ---
CPU RAM used: 1934.82 MB
GPU 0 VRAM: 7747.00 MB / 16384.00 MB

🎮 Benchmarking on GPU:

🔥 Warming up (3 runs) on cuda...
🚀 Benchmarking (10 runs) on cuda...
✅ Average inference: 11.35 ms
📊 Total time: 113.53 ms | FPS: 88.1

--- Memory Usage (After GPU test) ---
CPU RAM used: 1935.82 MB
GPU 0 VRAM: 7753.00 MB / 16384.00 MB

⚡ Benchmarking with AMP:

🔥 Warming up (3 runs) on cuda...
🚀 Benchmarking (10 runs) on cuda...
✅ Average inference: 13.95 ms
📊 Total time: 139.51 ms | FPS: 71.7

--- Memory Usage (After AMP test)

  _warn_prf(average, modifier, msg_start, len(result))


In [14]:
trained_model = timm.create_model('efficientvit_b1.r256_in1k', pretrained=True)
trained_model.load_state_dict(torch.load('best_model.pt', weights_only=True))

<All keys matched successfully>

In [15]:
main(trained_model, path)


🔍 Initial memory state:

--- Memory Usage (Before loading model) ---
CPU RAM used: 1953.30 MB
GPU 0 VRAM: 7765.00 MB / 16384.00 MB
📏 Model size: 34.77 MB

🔍 Initial memory state:

--- Memory Usage (Model loaded) ---
CPU RAM used: 1953.30 MB
GPU 0 VRAM: 7765.00 MB / 16384.00 MB

🧪 Benchmarking on CPU:

🔥 Warming up (3 runs) on cpu...
🚀 Benchmarking (10 runs) on cpu...
✅ Average inference: 32.71 ms
📊 Total time: 327.14 ms | FPS: 30.6

--- Memory Usage (After CPU benchmark) ---
CPU RAM used: 1953.30 MB
GPU 0 VRAM: 7765.00 MB / 16384.00 MB

🎮 Benchmarking on GPU:

🔥 Warming up (3 runs) on cuda...
🚀 Benchmarking (10 runs) on cuda...
✅ Average inference: 10.86 ms
📊 Total time: 108.61 ms | FPS: 92.1

--- Memory Usage (After GPU test) ---
CPU RAM used: 1953.30 MB
GPU 0 VRAM: 7769.00 MB / 16384.00 MB

⚡ Benchmarking with AMP:

🔥 Warming up (3 runs) on cuda...
🚀 Benchmarking (10 runs) on cuda...
✅ Average inference: 13.69 ms
📊 Total time: 136.86 ms | FPS: 73.1

--- Memory Usage (After AMP test)

  _warn_prf(average, modifier, msg_start, len(result))
