In [14]:
import warnings
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torch.nn as nn
import torch.optim as optim
import yaml
from dvclive import Live
from PIL import Image, ImageFile
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader
from torchvision import models, transforms
from tqdm import tqdm

# Let PIL decode truncated image files instead of raising in the middle of an epoch.
ImageFile.LOAD_TRUNCATED_IMAGES = True

# Pick the fastest available backend: CUDA, then Apple MPS, then CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(f"Device: {device}")


Device: mps


In [15]:
with open('../params.yaml', 'r', encoding='utf-8') as f:
    params = yaml.safe_load(f)

# Seed the global NumPy and PyTorch RNGs from the same value used for the train/val
# split so that head initialisation and random augmentations are repeatable.
# (Full determinism on GPU/MPS backends is not guaranteed by seeding alone.)
seed = params['data']['random_state']
np.random.seed(seed)
torch.manual_seed(seed)

print("Parameters:")
print(yaml.dump(params, default_flow_style=False))


Parameters:
augmentation:
  brightness: 0.2
  contrast: 0.2
  crop_size: 256
  resize: 256
data:
  random_state: 42
  train_test_split: 0.2
model:
  name: resnet50
  num_classes: 5
  pretrained: IMAGENET1K_V1
scheduler:
  gamma: 0.1
  step_size: 5
training:
  batch_size: 32
  learning_rate: 0.001
  num_epochs: 20
  optimizer: Adam



In [16]:
df = pd.read_csv('../data/balanced_animals_dataset.csv')

# Fail fast if the columns used later (label mapping, image path construction) are absent.
required_columns = {'scientific_name', 'uuid'}
missing_columns = required_columns - set(df.columns)
assert not missing_columns, f"Dataset is missing required columns: {sorted(missing_columns)}"

print(f"Dataset shape: {df.shape}")


Dataset shape: (12000, 6)


In [17]:
# Build the class <-> index mapping from *sorted* species names so it is deterministic
# and does not depend on the row order of the CSV. Missing names would otherwise become
# a NaN "class" (and break sorting), so reject them up front.
assert df['scientific_name'].notna().all(), "scientific_name contains missing values"
species_names = sorted(df['scientific_name'].unique())
label_to_idx = {label: idx for idx, label in enumerate(species_names)}
idx_to_label = dict(enumerate(species_names))
df['label'] = df['scientific_name'].map(label_to_idx)


In [18]:
# Stratified hold-out split: each class keeps the same proportion in train and val,
# and the split is reproducible through the configured random_state.
train_df, val_df = train_test_split(
    df,
    random_state=params['data']['random_state'],
    stratify=df['label'],
    test_size=params['data']['train_test_split'],
)

print(f"Train: {len(train_df)}, Val: {len(val_df)}")


Train: 9600, Val: 2400


In [19]:
class AnimalDataset(Dataset):
    """Map-style dataset yielding `(image, label)` pairs for the animal classifier.

    Each image is read from
    `<base_path>/<scientific_name with spaces replaced by underscores>/<uuid>.jpg`.

    Args:
        dataframe: table with at least `scientific_name`, `uuid` and integer `label`
            columns. Its index is reset so positional lookups match `__len__`.
        transform: optional callable applied to the loaded PIL image.
        base_path: root directory containing one sub-folder per species
            (defaults to the original hard-coded location).
        placeholder_size: `(width, height)` of the black image substituted when a file
            cannot be loaded.
    """

    def __init__(self, dataframe, transform=None, base_path='../animal_images',
                 placeholder_size=(224, 224)):
        self.df = dataframe.reset_index(drop=True)
        self.transform = transform
        self.base_path = Path(base_path)
        self.placeholder_size = placeholder_size
        # Paths that failed to load, kept so bad files are visible rather than silently
        # becoming black training images. NOTE: with DataLoader num_workers > 0 each
        # worker process fills its own copy of this set.
        self.failed_paths = set()

    def __len__(self):
        return len(self.df)

    def _image_path(self, row):
        """Return the on-disk path of the image described by `row`."""
        species = row['scientific_name'].replace(' ', '_')
        return self.base_path / species / f"{row['uuid']}.jpg"

    def _load_image(self, img_path):
        """Open `img_path` as RGB; on failure warn once per path and return a black placeholder."""
        try:
            # The context manager closes the underlying file handle promptly;
            # convert() returns an independent, fully loaded image.
            with Image.open(img_path) as raw:
                return raw.convert('RGB')
        except Exception as exc:  # best-effort: one bad file must not abort the epoch
            if img_path not in self.failed_paths:
                self.failed_paths.add(img_path)
                warnings.warn(
                    f"Could not load image {img_path} ({exc!r}); using a black placeholder.",
                    RuntimeWarning,
                    stacklevel=2,
                )
            return Image.new('RGB', self.placeholder_size, color=(0, 0, 0))

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        image = self._load_image(self._image_path(row))

        if self.transform:
            image = self.transform(image)

        return image, int(row['label'])


In [20]:
# ImageNet channel statistics, matching the ImageNet-pretrained backbone weights.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
aug = params['augmentation']

# Training: Resize(int) scales the shorter side, then a random crop, horizontal flip and
# colour jitter provide augmentation.
train_transform = transforms.Compose([
    transforms.Resize(aug['resize']),
    transforms.RandomCrop(aug['crop_size']),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(
        brightness=aug['brightness'],
        contrast=aug['contrast'],
        # Bug fix: saturation used to reuse the `contrast` value unconditionally.
        # Read an explicit `saturation` key, falling back to `contrast` so existing
        # params.yaml files keep producing the same augmentation.
        saturation=aug.get('saturation', aug['contrast']),
    ),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

# Validation: deterministic resize + centre crop, same normalisation.
val_transform = transforms.Compose([
    transforms.Resize(aug['resize']),
    transforms.CenterCrop(aug['crop_size']),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])


In [21]:
train_dataset = AnimalDataset(train_df, transform=train_transform)
val_dataset = AnimalDataset(val_df, transform=val_transform)

batch_size = params['training']['batch_size']
# A dedicated, seeded generator makes the training shuffle order reproducible
# regardless of how much of the global torch RNG other cells have consumed.
shuffle_generator = torch.Generator().manual_seed(params['data']['random_state'])
train_loader = DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True, num_workers=0,
    generator=shuffle_generator,
)
# Validation order is fixed (shuffle=False) so predictions line up with val_df.
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=0)


In [22]:
model_name = params['model']['name']
num_classes = params['model']['num_classes']
pretrained = params['model'].get('pretrained')


def _resolve_weights(weights_enum, weights_name):
    """Return the `weights_enum` member named `weights_name` (e.g. 'IMAGENET1K_V1'), or None for random init."""
    return getattr(weights_enum, weights_name) if weights_name else None


# Build the backbone, swap its final Linear layer for a `num_classes`-way one, and
# record which module ("head") remains trainable.
if model_name == 'resnet50':
    model = models.resnet50(weights=_resolve_weights(models.ResNet50_Weights, pretrained))
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    head = model.fc

elif model_name == 'resnet101':
    model = models.resnet101(weights=_resolve_weights(models.ResNet101_Weights, pretrained))
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    head = model.fc

elif model_name == 'efficientnet_v2_m':
    model = models.efficientnet_v2_m(
        weights=_resolve_weights(models.EfficientNet_V2_M_Weights, pretrained))
    model.classifier[1] = nn.Linear(model.classifier[1].in_features, num_classes)
    head = model.classifier

elif model_name == 'convnext_base':
    model = models.convnext_base(weights=_resolve_weights(models.ConvNeXt_Base_Weights, pretrained))
    model.classifier[2] = nn.Linear(model.classifier[2].in_features, num_classes)
    head = model.classifier

elif model_name == 'vit_b_16':
    # NOTE(review): ViT-B/16 is built for 224x224 inputs by default; with the
    # configured crop_size (256 in the shown params) it will likely reject the batch.
    model = models.vit_b_16(weights=_resolve_weights(models.ViT_B_16_Weights, pretrained))
    model.heads.head = nn.Linear(model.heads.head.in_features, num_classes)
    head = model.heads

else:
    # Without this branch an unknown name silently reuses a `model` left over from a
    # previous run of this cell (or raises NameError on a fresh kernel).
    raise ValueError(f"Unsupported model name: {model_name!r}")

# Feature extraction: freeze every parameter, then unfreeze only the classification head.
for param in model.parameters():
    param.requires_grad = False
for param in head.parameters():
    param.requires_grad = True

model = model.to(device)
print(f"Model: {params['model']['name']}")
print(f"Trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")


Model: resnet50
Trainable parameters: 10,245


In [23]:
criterion = nn.CrossEntropyLoss()

# Optimise exactly the parameters the model-setup cell left trainable, instead of
# re-deriving the head module from the model name (the two chains could drift apart).
trainable_params = [p for p in model.parameters() if p.requires_grad]
if not trainable_params:
    raise ValueError("Model has no trainable parameters; check the freezing logic.")

# Honour `training.optimizer` from params.yaml (previously ignored: Adam was hard-coded).
optimizer_name = params['training'].get('optimizer', 'Adam')
optimizer_cls = getattr(optim, optimizer_name, None)
if not (isinstance(optimizer_cls, type) and issubclass(optimizer_cls, optim.Optimizer)):
    raise ValueError(f"Unknown optimizer in params.yaml: {optimizer_name!r}")
optimizer = optimizer_cls(trainable_params, lr=params['training']['learning_rate'])

# Multiply the learning rate by `gamma` every `step_size` epochs (stepped once per epoch).
scheduler = optim.lr_scheduler.StepLR(
    optimizer,
    step_size=params['scheduler']['step_size'],
    gamma=params['scheduler']['gamma'],
)


In [24]:
def train_epoch(model, loader, criterion, optimizer, device):
    """Run one optimisation pass over `loader` and return the epoch metrics.

    Args:
        model: network to train; switched to train mode here.
        loader: iterable of `(images, labels)` batches.
        criterion: loss called as `criterion(outputs, labels)`; assumed to return the
            batch *mean* (the `nn.CrossEntropyLoss` default) — TODO confirm if another
            reduction is ever used.
        optimizer: optimizer stepping the trainable parameters.
        device: device the batches are moved to before the forward pass.

    Returns:
        `(mean_loss, accuracy_percent)`. `mean_loss` is a per-sample average, so a
        short final batch is not over-weighted.

    Raises:
        ValueError: if `loader` yields no samples.
    """
    # NOTE(review): train() also puts BatchNorm layers in training mode, so their running
    # statistics keep updating even though the backbone weights are frozen.
    model.train()
    loss_sum = 0.0
    correct = 0
    total = 0

    for images, labels in tqdm(loader, desc="Training"):
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        batch_size = labels.size(0)
        # Re-weight the batch-mean loss by batch size to obtain a per-sample average.
        loss_sum += loss.item() * batch_size
        correct += outputs.argmax(dim=1).eq(labels).sum().item()
        total += batch_size

    if total == 0:
        raise ValueError("train_epoch: loader yielded no samples")
    return loss_sum / total, 100.0 * correct / total


In [25]:
def validate(model, loader, criterion, device):
    """Evaluate `model` on `loader` without gradient tracking.

    Args:
        model: network to evaluate; switched to eval mode here.
        loader: iterable of `(images, labels)` batches.
        criterion: loss called as `criterion(outputs, labels)`; assumed to return the
            batch *mean* (the `nn.CrossEntropyLoss` default) — TODO confirm if another
            reduction is ever used.
        device: device the batches are moved to before the forward pass.

    Returns:
        `(mean_loss, accuracy_percent, preds, labels)`. `mean_loss` is a per-sample
        average. `preds`/`labels` are NumPy arrays in loader order (1-D, assuming the
        model returns `(batch, num_classes)` logits).

    Raises:
        ValueError: if `loader` yields no samples.
    """
    model.eval()
    loss_sum = 0.0
    correct = 0
    total = 0
    pred_batches = []
    label_batches = []

    with torch.no_grad():
        for images, labels in tqdm(loader, desc="Validation"):
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            predicted = outputs.argmax(dim=1)

            batch_size = labels.size(0)
            # Re-weight the batch-mean loss by batch size to obtain a per-sample average.
            loss_sum += criterion(outputs, labels).item() * batch_size
            correct += predicted.eq(labels).sum().item()
            total += batch_size

            pred_batches.append(predicted.cpu().numpy())
            label_batches.append(labels.cpu().numpy())

    if total == 0:
        raise ValueError("validate: loader yielded no samples")
    return (
        loss_sum / total,
        100.0 * correct / total,
        np.concatenate(pred_batches),
        np.concatenate(label_batches),
    )


In [26]:
# NOTE(review): this cell continues training whatever `model`/`optimizer`/`scheduler`
# currently hold, so re-running it without re-running the setup cells resumes training.
num_epochs = params['training']['num_epochs']
best_val_acc = 0.0
# Predictions from the best epoch, so the confusion matrix matches the saved checkpoint
# and `best_val_accuracy` (previously it used whatever the *last* epoch produced).
best_val_preds, best_val_labels = None, None
checkpoint_dir = Path('../models')
checkpoint_dir.mkdir(parents=True, exist_ok=True)
class_names = [idx_to_label[i] for i in range(len(idx_to_label))]


def _save_confusion_matrix(y_true, y_pred, tick_labels, out_path):
    """Render a confusion-matrix heatmap for integer class ids to `out_path` and return it as a Path.

    Passing `labels=range(len(tick_labels))` keeps the matrix square and aligned with
    the tick labels even if some class never appears in `y_true`/`y_pred`.
    """
    cm = confusion_matrix(y_true, y_pred, labels=list(range(len(tick_labels))))
    fig, ax = plt.subplots(figsize=(10, 8))
    sns.heatmap(
        cm,
        annot=True,
        fmt='d',
        cmap='Blues',
        xticklabels=tick_labels,
        yticklabels=tick_labels,
        ax=ax,
    )
    ax.set_ylabel('True Label')
    ax.set_xlabel('Predicted Label')
    ax.set_title('Confusion Matrix')
    plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
    plt.setp(ax.get_yticklabels(), rotation=0)
    fig.tight_layout()

    out_path = Path(out_path)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(out_path, dpi=150, bbox_inches='tight')
    plt.close(fig)  # free the figure instead of leaving it open in pyplot
    return out_path


with Live(dir='../dvclive', save_dvc_exp=True) as live:

    live.log_params(params)

    for epoch in range(num_epochs):
        print(f"\nEpoch {epoch+1}/{num_epochs}")

        train_loss, train_acc = train_epoch(
            model, train_loader, criterion, optimizer, device
        )
        val_loss, val_acc, val_preds, val_labels = validate(
            model, val_loader, criterion, device
        )

        live.log_metric('train/loss', train_loss)
        live.log_metric('train/accuracy', train_acc)
        live.log_metric('val/loss', val_loss)
        live.log_metric('val/accuracy', val_acc)
        live.next_step()

        print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%")
        print(f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%")

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            best_val_preds, best_val_labels = val_preds, val_labels
            torch.save({
                'model_state_dict': model.state_dict(),
                'label_to_idx': label_to_idx,
                'idx_to_label': idx_to_label,
                'params': params,
                'epoch': epoch + 1,
            }, checkpoint_dir / 'best_model.pth')
            print(f"Best model saved: {val_acc:.2f}%")

        # StepLR is stepped once per epoch, after the optimizer updates.
        scheduler.step()

    # Only plot when some epoch produced a checkpoint (e.g. not when num_epochs == 0).
    if best_val_preds is not None:
        confusion_matrix_path = _save_confusion_matrix(
            best_val_labels,
            best_val_preds,
            class_names,
            '../dvclive/plots/confusion_matrix.png',
        )
        live.log_image('confusion_matrix.png', str(confusion_matrix_path))

    live.log_metric('best_val_accuracy', best_val_acc)



Epoch 1/20


Training: 100%|██████████| 300/300 [01:29<00:00,  3.35it/s]
Validation: 100%|██████████| 75/75 [00:17<00:00,  4.18it/s]


Train Loss: 1.1636, Train Acc: 54.07%
Val Loss: 0.9429, Val Acc: 63.04%
Best model saved: 63.04%

Epoch 2/20


Training: 100%|██████████| 300/300 [01:25<00:00,  3.49it/s]
Validation: 100%|██████████| 75/75 [00:17<00:00,  4.32it/s]


Train Loss: 0.9890, Train Acc: 62.26%
Val Loss: 0.9447, Val Acc: 64.92%
Best model saved: 64.92%

Epoch 3/20


Training: 100%|██████████| 300/300 [01:24<00:00,  3.56it/s]
Validation: 100%|██████████| 75/75 [00:17<00:00,  4.29it/s]


Train Loss: 0.9542, Train Acc: 63.24%
Val Loss: 0.9301, Val Acc: 65.33%
Best model saved: 65.33%

Epoch 4/20


Training: 100%|██████████| 300/300 [01:23<00:00,  3.58it/s]
Validation: 100%|██████████| 75/75 [00:17<00:00,  4.28it/s]


Train Loss: 0.9495, Train Acc: 63.54%
Val Loss: 0.9090, Val Acc: 66.17%
Best model saved: 66.17%

Epoch 5/20


Training: 100%|██████████| 300/300 [01:23<00:00,  3.57it/s]
Validation: 100%|██████████| 75/75 [00:17<00:00,  4.30it/s]


Train Loss: 0.9405, Train Acc: 64.27%
Val Loss: 0.9362, Val Acc: 64.75%

Epoch 6/20


Training: 100%|██████████| 300/300 [01:24<00:00,  3.54it/s]
Validation: 100%|██████████| 75/75 [00:17<00:00,  4.24it/s]


Train Loss: 0.8585, Train Acc: 67.33%
Val Loss: 0.8811, Val Acc: 67.46%
Best model saved: 67.46%

Epoch 7/20


Training: 100%|██████████| 300/300 [01:23<00:00,  3.58it/s]
Validation: 100%|██████████| 75/75 [00:17<00:00,  4.31it/s]


Train Loss: 0.8630, Train Acc: 66.84%
Val Loss: 0.8743, Val Acc: 67.58%
Best model saved: 67.58%

Epoch 8/20


Training: 100%|██████████| 300/300 [01:23<00:00,  3.59it/s]
Validation: 100%|██████████| 75/75 [00:17<00:00,  4.34it/s]


Train Loss: 0.8536, Train Acc: 67.76%
Val Loss: 0.8687, Val Acc: 68.00%
Best model saved: 68.00%

Epoch 9/20


Training: 100%|██████████| 300/300 [01:23<00:00,  3.59it/s]
Validation: 100%|██████████| 75/75 [00:17<00:00,  4.32it/s]


Train Loss: 0.8619, Train Acc: 66.96%
Val Loss: 0.8694, Val Acc: 66.96%

Epoch 10/20


Training: 100%|██████████| 300/300 [01:24<00:00,  3.53it/s]
Validation: 100%|██████████| 75/75 [00:18<00:00,  4.14it/s]


Train Loss: 0.8560, Train Acc: 67.64%
Val Loss: 0.8783, Val Acc: 67.33%

Epoch 11/20


Training:  77%|███████▋  | 231/300 [01:07<00:20,  3.42it/s]


KeyboardInterrupt: 