In [1]:
import os
import torchvision.models as models
import numpy as np
import torch
import torchutils as tu
import torch.nn as nn
from torchvision.models import densenet121, DenseNet121_Weights
from torchvision.models import resnet50, DenseNet121_Weights

from torch.utils.data import Dataset, DataLoader, Subset
from torchvision import transforms
from PIL import Image

In [2]:
np.random.seed(42)
torch.manual_seed(42)
torch.cuda.manual_seed(42)

In [3]:
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
print(DEVICE)

cuda


In [38]:

data = """
1 001.Black_footed_Albatross
2 002.Laysan_Albatross
3 003.Sooty_Albatross
4 004.Groove_billed_Ani
5 005.Crested_Auklet
6 006.Least_Auklet
7 007.Parakeet_Auklet
8 008.Rhinoceros_Auklet
9 009.Brewer_Blackbird
10 010.Red_winged_Blackbird
11 011.Rusty_Blackbird
12 012.Yellow_headed_Blackbird
13 013.Bobolink
14 014.Indigo_Bunting
15 015.Lazuli_Bunting
16 016.Painted_Bunting
17 017.Cardinal
18 018.Spotted_Catbird
19 019.Gray_Catbird
20 020.Yellow_breasted_Chat
21 021.Eastern_Towhee
22 022.Chuck_will_Widow
23 023.Brandt_Cormorant
24 024.Red_faced_Cormorant
25 025.Pelagic_Cormorant
26 026.Bronzed_Cowbird
27 027.Shiny_Cowbird
28 028.Brown_Creeper
29 029.American_Crow
30 030.Fish_Crow
31 031.Black_billed_Cuckoo
32 032.Mangrove_Cuckoo
33 033.Yellow_billed_Cuckoo
34 034.Gray_crowned_Rosy_Finch
35 035.Purple_Finch
36 036.Northern_Flicker
37 037.Acadian_Flycatcher
38 038.Great_Crested_Flycatcher
39 039.Least_Flycatcher
40 040.Olive_sided_Flycatcher
41 041.Scissor_tailed_Flycatcher
42 042.Vermilion_Flycatcher
43 043.Yellow_bellied_Flycatcher
44 044.Frigatebird
45 045.Northern_Fulmar
46 046.Gadwall
47 047.American_Goldfinch
48 048.European_Goldfinch
49 049.Boat_tailed_Grackle
50 050.Eared_Grebe
51 051.Horned_Grebe
52 052.Pied_billed_Grebe
53 053.Western_Grebe
54 054.Blue_Grosbeak
55 055.Evening_Grosbeak
56 056.Pine_Grosbeak
57 057.Rose_breasted_Grosbeak
58 058.Pigeon_Guillemot
59 059.California_Gull
60 060.Glaucous_winged_Gull
61 061.Heermann_Gull
62 062.Herring_Gull
63 063.Ivory_Gull
64 064.Ring_billed_Gull
65 065.Slaty_backed_Gull
66 066.Western_Gull
67 067.Anna_Hummingbird
68 068.Ruby_throated_Hummingbird
69 069.Rufous_Hummingbird
70 070.Green_Violetear
71 071.Long_tailed_Jaeger
72 072.Pomarine_Jaeger
73 073.Blue_Jay
74 074.Florida_Jay
75 075.Green_Jay
76 076.Dark_eyed_Junco
77 077.Tropical_Kingbird
78 078.Gray_Kingbird
79 079.Belted_Kingfisher
80 080.Green_Kingfisher
81 081.Pied_Kingfisher
82 082.Ringed_Kingfisher
83 083.White_breasted_Kingfisher
84 084.Red_legged_Kittiwake
85 085.Horned_Lark
86 086.Pacific_Loon
87 087.Mallard
88 088.Western_Meadowlark
89 089.Hooded_Merganser
90 090.Red_breasted_Merganser
91 091.Mockingbird
92 092.Nighthawk
93 093.Clark_Nutcracker
94 094.White_breasted_Nuthatch
95 095.Baltimore_Oriole
96 096.Hooded_Oriole
97 097.Orchard_Oriole
98 098.Scott_Oriole
99 099.Ovenbird
100 100.Brown_Pelican
101 101.White_Pelican
102 102.Western_Wood_Pewee
103 103.Sayornis
104 104.American_Pipit
105 105.Whip_poor_Will
106 106.Horned_Puffin
107 107.Common_Raven
108 108.White_necked_Raven
109 109.American_Redstart
110 110.Geococcyx
111 111.Loggerhead_Shrike
112 112.Great_Grey_Shrike
113 113.Baird_Sparrow
114 114.Black_throated_Sparrow
115 115.Brewer_Sparrow
116 116.Chipping_Sparrow
117 117.Clay_colored_Sparrow
118 118.House_Sparrow
119 119.Field_Sparrow
120 120.Fox_Sparrow
121 121.Grasshopper_Sparrow
122 122.Harris_Sparrow
123 123.Henslow_Sparrow
124 124.Le_Conte_Sparrow
125 125.Lincoln_Sparrow
126 126.Nelson_Sharp_tailed_Sparrow
127 127.Savannah_Sparrow
128 128.Seaside_Sparrow
129 129.Song_Sparrow
130 130.Tree_Sparrow
131 131.Vesper_Sparrow
132 132.White_crowned_Sparrow
133 133.White_throated_Sparrow
134 134.Cape_Glossy_Starling
135 135.Bank_Swallow
136 136.Barn_Swallow
137 137.Cliff_Swallow
138 138.Tree_Swallow
139 139.Scarlet_Tanager
140 140.Summer_Tanager
141 141.Artic_Tern
142 142.Black_Tern
143 143.Caspian_Tern
144 144.Common_Tern
145 145.Elegant_Tern
146 146.Forsters_Tern
147 147.Least_Tern
148 148.Green_tailed_Towhee
149 149.Brown_Thrasher
150 150.Sage_Thrasher
151 151.Black_capped_Vireo
152 152.Blue_headed_Vireo
153 153.Philadelphia_Vireo
154 154.Red_eyed_Vireo
155 155.Warbling_Vireo
156 156.White_eyed_Vireo
157 157.Yellow_throated_Vireo
158 158.Bay_breasted_Warbler
159 159.Black_and_white_Warbler
160 160.Black_throated_Blue_Warbler
161 161.Blue_winged_Warbler
162 162.Canada_Warbler
163 163.Cape_May_Warbler
164 164.Cerulean_Warbler
165 165.Chestnut_sided_Warbler
166 166.Golden_winged_Warbler
167 167.Hooded_Warbler
168 168.Kentucky_Warbler
169 169.Magnolia_Warbler
170 170.Mourning_Warbler
171 171.Myrtle_Warbler
172 172.Nashville_Warbler
173 173.Orange_crowned_Warbler
174 174.Palm_Warbler
175 175.Pine_Warbler
176 176.Prairie_Warbler
177 177.Prothonotary_Warbler
178 178.Swainson_Warbler
179 179.Tennessee_Warbler
180 180.Wilson_Warbler
181 181.Worm_eating_Warbler
182 182.Yellow_Warbler
183 183.Northern_Waterthrush
184 184.Louisiana_Waterthrush
185 185.Bohemian_Waxwing
186 186.Cedar_Waxwing
187 187.American_Three_toed_Woodpecker
188 188.Pileated_Woodpecker
189 189.Red_bellied_Woodpecker
190 190.Red_cockaded_Woodpecker
191 191.Red_headed_Woodpecker
192 192.Downy_Woodpecker
193 193.Bewick_Wren
194 194.Cactus_Wren
195 195.Carolina_Wren
196 196.House_Wren
197 197.Marsh_Wren
198 198.Rock_Wren
199 199.Winter_Wren
200 200.Common_Yellowthroat
"""

# Разбиваем строки по переводу строки
lines = data.strip().split('\n')

# Извлекаем названия птиц, удаляя префиксы с номерами
bird_names = [line.split(' ', 1)[1].split('.', 1)[1] for line in lines]



# Выводим результат
#for name in bird_names:
#    print(name)

In [39]:
bird_names

['Black_footed_Albatross',
 'Laysan_Albatross',
 'Sooty_Albatross',
 'Groove_billed_Ani',
 'Crested_Auklet',
 'Least_Auklet',
 'Parakeet_Auklet',
 'Rhinoceros_Auklet',
 'Brewer_Blackbird',
 'Red_winged_Blackbird',
 'Rusty_Blackbird',
 'Yellow_headed_Blackbird',
 'Bobolink',
 'Indigo_Bunting',
 'Lazuli_Bunting',
 'Painted_Bunting',
 'Cardinal',
 'Spotted_Catbird',
 'Gray_Catbird',
 'Yellow_breasted_Chat',
 'Eastern_Towhee',
 'Chuck_will_Widow',
 'Brandt_Cormorant',
 'Red_faced_Cormorant',
 'Pelagic_Cormorant',
 'Bronzed_Cowbird',
 'Shiny_Cowbird',
 'Brown_Creeper',
 'American_Crow',
 'Fish_Crow',
 'Black_billed_Cuckoo',
 'Mangrove_Cuckoo',
 'Yellow_billed_Cuckoo',
 'Gray_crowned_Rosy_Finch',
 'Purple_Finch',
 'Northern_Flicker',
 'Acadian_Flycatcher',
 'Great_Crested_Flycatcher',
 'Least_Flycatcher',
 'Olive_sided_Flycatcher',
 'Scissor_tailed_Flycatcher',
 'Vermilion_Flycatcher',
 'Yellow_bellied_Flycatcher',
 'Frigatebird',
 'Northern_Fulmar',
 'Gadwall',
 'American_Goldfinch',
 'Euro

In [4]:
# Путь к данным
data_dir = '/home/oldmovielover/Загрузки/CUB_200_2011/CUB_200_2011'
images_dir = os.path.join(data_dir, 'images')
train_test_split_file = os.path.join(data_dir, 'train_test_split.txt')

# Определим трансформации для данных
data_transform_train = transforms.Compose([
    transforms.RandomResizedCrop(224),  # Рандомное изменение размера
    transforms.RandomRotation(15),      # Рандомный поворот на 15 градусов
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

data_transform_test = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Считываем информацию о разделении на train и test
def get_train_test_split():
    with open(train_test_split_file, 'r') as file:
        lines = file.readlines()
    
    # Парсим строки, где 0 - test, 1 - train
    train_indices = []
    valid_indices = []
    
    for line in lines:
        image_id, is_train = line.strip().split()
        image_id = int(image_id) - 1  # Нумерация с 1, поэтому уменьшим на 1
        if is_train == '1':
            train_indices.append(image_id)
        else:
            valid_indices.append(image_id)
    
    return train_indices, valid_indices

# Кастомный датасет для загрузки данных
class CUBDataset(Dataset):
    def __init__(self, images_dir, split_indices, transform=None):
        self.images_dir = images_dir
        self.split_indices = split_indices
        self.transform = transform
        self.image_paths = []
        self.labels = []

        # Считываем метки классов
        self.class_names = sorted(os.listdir(images_dir))  # Сортируем по имени папки

        # Заполняем пути к изображениям и их метки
        for class_idx, class_name in enumerate(self.class_names):
            class_folder = os.path.join(self.images_dir, class_name)
            
            if os.path.exists(class_folder):
                # Собираем изображения для каждой папки класса
                for image_name in os.listdir(class_folder):
                    image_path = os.path.join(class_folder, image_name)
                    self.image_paths.append(image_path)
                    self.labels.append(class_idx)
            else:
                print(f"Папка для класса {class_name} не найдена: {class_folder}")

        # Используем только индексы из train или valid
        self.image_paths = [self.image_paths[i] for i in self.split_indices]
        self.labels = [self.labels[i] for i in self.split_indices]

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image_path = self.image_paths[idx]
        label = self.labels[idx]

        image = Image.open(image_path).convert('RGB')

        if self.transform:
            image = self.transform(image)

        return image, label

# Загружаем train и test индексы
train_indices, valid_indices = get_train_test_split()

# Создаем тренировочный и тестовый датасеты
train_dataset = CUBDataset(images_dir, train_indices, transform=data_transform_train)
valid_dataset = CUBDataset(images_dir, valid_indices, transform=data_transform_test)

# Создаем DataLoader для тренировки и тестирования
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=32, shuffle=False)

# Проверка: выводим количество элементов в датасетах
print(f"Training set size: {len(train_loader.dataset)}")
print(f"Validation set size: {len(valid_loader.dataset)}")

Training set size: 5994
Validation set size: 5794


In [5]:
next(iter(valid_loader))[-1]

tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 1, 1])

In [6]:
train_iterator = iter(valid_loader)

# Извлекаем три следующих батча
batch_1 = next(train_iterator)
batch_2 = next(train_iterator)
batch_3 = next(train_iterator)
batch_4 = next(train_iterator)
batch_5 = next(train_iterator)
batch_6 = next(train_iterator)

# Каждый батч — это кортеж (image, label)
# Можете получить данные и метки из этих батчей
images_1, labels_1 = batch_1
images_2, labels_2 = batch_2
images_3, labels_3 = batch_3
images_4, labels_4 = batch_4
images_5, labels_5 = batch_5
images_6, labels_6 = batch_6

# Выводим размеры изображений и меток для проверки
print("Batch 1 Labels shape:", labels_1)
print("Batch 2 Labels shape:", labels_2)
print("Batch 3 Labels shape:", labels_3)
print("Batch 1 Labels shape:", labels_4)
print("Batch 2 Labels shape:", labels_5)
print("Batch 3 Labels shape:", labels_6)

Batch 1 Labels shape: tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 1, 1])
Batch 2 Labels shape: tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 2, 2, 2, 2])
Batch 3 Labels shape: tensor([2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        3, 3, 3, 3, 3, 3, 3, 3])
Batch 1 Labels shape: tensor([3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4,
        4, 4, 4, 4, 4, 4, 4, 4])
Batch 2 Labels shape: tensor([4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 6, 6,
        6, 6, 6, 6, 6, 6, 6, 6])
Batch 3 Labels shape: tensor([6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
        8, 8, 8, 8, 8, 8, 8, 8])


In [36]:
class MyDenseNet(nn.Module):
  def __init__(self) -> None:
    super().__init__()

    self.model = models.densenet121(weights=DenseNet121_Weights.DEFAULT).to(DEVICE)
    self.model.classifier = nn.Linear(in_features=1024, out_features=200)

    for param in self.model.parameters():
        param.requires_grad = False

    for param in self.model.features.denseblock4.denselayer16.parameters():
        param.requires_grad = True
      
    self.model.classifier.weight.requires_grad = True
    self.model.classifier.bias.requires_grad = True

  def forward(self, x):
      return self.model(x)

model = MyDenseNet()
model.to(DEVICE);

In [37]:
for param in model.parameters():
    print(param.requires_grad)

False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
Fals

In [10]:
optimizer = torch.optim.Adam(model.parameters(), lr=0.002)
criterion = torch.nn.CrossEntropyLoss()

In [11]:
train_epoch_acc = []
train_epoch_losses = []
valid_epoch_losses = []
valid_epoch_acc = []

# Перебор эпох
for epoch in range(1):
    # Обучение модели
    model.train()
    loss_batch = []
    acc_batch = []

    for images, labels in train_loader:
        images = images.to(DEVICE)
        labels = labels.to(DEVICE)

        # Получение предсказаний
        preds = model(images)  # без squeeze(-1), если выход имеет правильную форму

        # Вычисление потерь
        loss = criterion(preds, labels)
        loss_batch.append(loss.item())

        # Вычисление точности
        accuracy = (preds.argmax(dim=1) == labels).cpu().numpy().mean()
        acc_batch.append(accuracy)

        # Обратное распространение и обновление весов
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Сохранение статистики
    train_epoch_losses.append(np.mean(loss_batch))
    train_epoch_acc.append(np.mean(acc_batch))

    # Оценка модели
    model.eval()
    loss_batch = []
    acc_batch = []
    for images, labels in valid_loader:
        images = images.to(DEVICE)
        labels = labels.to(DEVICE)

        # Без обновления градиентов
        with torch.no_grad():
            preds = model(images)

        # Вычисление потерь
        loss = criterion(preds, labels)
        loss_batch.append(loss.item())

        # Вычисление точности
        accuracy = (preds.argmax(dim=1) == labels).cpu().numpy().mean()
        acc_batch.append(accuracy)

    # Сохранение статистики для валидации
    valid_epoch_losses.append(np.mean(loss_batch))
    valid_epoch_acc.append(np.mean(acc_batch))

    # Вывод статистики
    print(f'Epoch: {epoch}  loss_train: {train_epoch_losses[-1]:.3f}, loss_valid: {valid_epoch_losses[-1]:.3f}')
    print(f'\t  metrics_train: {train_epoch_acc[-1]:.3f}, metrics_valid: {valid_epoch_acc[-1]:.3f}')

Epoch: 0  loss_train: 4.480, loss_valid: 3.134
	  metrics_train: 0.126, metrics_valid: 0.299


In [None]:
for name, module in self.model.named_children():
            if "denselayer16" in name:  # Можете использовать имя нужного слоя
                for param in module.parameters():
                    param.requires_grad = True

In [12]:
for param in model.parameters():
    print(param.requires_grad)

False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
Fals

In [40]:
# Устройство (CPU или GPU)
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Пути к данным
DATA_DIR = "/home/oldmovielover/Загрузки/CUB_200_2011/CUB_200_2011/images"  # Путь к папке с изображениями
IMAGES_FILE = "/home/oldmovielover/Загрузки/CUB_200_2011/CUB_200_2011/images.txt"
LABELS_FILE = "/home/oldmovielover/Загрузки/CUB_200_2011/CUB_200_2011/image_class_labels.txt"
SPLIT_FILE = "/home/oldmovielover/Загрузки/CUB_200_2011/CUB_200_2011/train_test_split.txt"

# Гиперпараметры
BATCH_SIZE = 32
NUM_CLASSES = 200
LEARNING_RATE = 0.001
EPOCHS = 25

# Загрузка данных из файлов
def load_data(images_file, labels_file, split_file, data_dir):
    # Загрузка всех изображений и их меток
    images_df = pd.read_csv(images_file, sep=' ', header=None, names=['id', 'file_path'])
    labels_df = pd.read_csv(labels_file, sep=' ', header=None, names=['id', 'label'])
    split_df = pd.read_csv(split_file, sep=' ', header=None, names=['id', 'is_train'])

    # Объединение всех данных в один DataFrame
    data = images_df.merge(labels_df, on='id').merge(split_df, on='id')
    data['file_path'] = data['file_path'].apply(lambda x: os.path.join(data_dir, x))
    
    # Разделение на обучающую и тестовую выборки
    train_data = data[data['is_train'] == 1]
    val_data = data[data['is_train'] == 0]

    return train_data['file_path'].values, train_data['label'].values - 1, \
           val_data['file_path'].values, val_data['label'].values - 1

# Загрузка данных
train_image_paths, train_labels, val_image_paths, val_labels = load_data(IMAGES_FILE, LABELS_FILE, SPLIT_FILE, DATA_DIR)

# Класс для набора данных
class CUBDataset(Dataset):
    def __init__(self, image_paths, labels, transform=None):
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        image_path = self.image_paths[idx]
        label = self.labels[idx]
        image = Image.open(image_path).convert("RGB")

        if self.transform:
            image = self.transform(image)

        return image, label

# Трансформации данных
train_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

val_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Создание наборов данных и DataLoader
train_dataset = CUBDataset(train_image_paths, train_labels, transform=train_transforms)
val_dataset = CUBDataset(val_image_paths, val_labels, transform=val_transforms)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

# Модель DenseNet121
class MyDenseNet(nn.Module):
    def __init__(self):
        super().__init__()
        # Загрузка предобученной модели
        self.model = models.densenet121(weights=DenseNet121_Weights.DEFAULT).to(DEVICE)
        self.model.classifier = nn.Linear(in_features=1024, out_features=200)
    
        for param in self.model.parameters():
            param.requires_grad = False
    
        for param in self.model.features.denseblock4.denselayer16.parameters():
            param.requires_grad = True
          
        self.model.classifier.weight.requires_grad = True
        self.model.classifier.bias.requires_grad = True
    
    def forward(self, x):
        return self.model(x)

# Инициализация модели
model = MyDenseNet().to(DEVICE)

# Определение функции потерь и оптимизатора
criterion = nn.CrossEntropyLoss()  # Если нужно использовать веса классов, добавьте weight=class_weights
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

# Функция для обучения модели
def train_model(model, criterion, optimizer, train_loader, val_loader, epochs):
    best_accuracy = 0.0

    for epoch in range(epochs):
        # Обучение
        model.train()
        running_loss = 0.0
        for images, labels in train_loader:
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            
            running_loss += loss.item()
        
        # Оценка на валидационной выборке
        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for images, labels in val_loader:
                images, labels = images.to(DEVICE), labels.to(DEVICE)
                outputs = model(images)
                _, predicted = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        
        accuracy = correct / total
        print(f"Epoch [{epoch+1}/{epochs}], Loss: {running_loss/len(train_loader):.4f}, Validation Accuracy: {accuracy:.4f}")


In [None]:
train_model(model, criterion, optimizer, train_loader, val_loader, EPOCHS)

In [None]:
import torch
import torch.nn as nn
from torchvision import models

class MyResNet(nn.Module):
    def __init__(self) -> None:
        super().__init__()

        # Загружаем ResNet-50
        self.model = resnet50(weights=ResNet50_Weights.DEFAULT)
        
        # Заменяем последний слой (fc) для 200 классов (CUB-200-2011)
        num_ftrs = self.model.fc.in_features
        self.model.fc = nn.Sequential(
            nn.Linear(num_ftrs, 2048),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(2048, 1024),
            nn.ReLU(inplace=True),
            nn.Dropout(0.4),
            nn.Linear(1024, 200),  # 200 классов для CUB-200-2011
            nn.LogSoftmax(dim=1)
        )
        
        # Замораживаем все параметры
        for param in self.model.parameters():
            param.requires_grad = False
        
        # Размораживаем последние слои (например, layer4)
        for param in self.model.layer4.parameters():
            param.requires_grad = True

    def forward(self, x):
        return self.model(x)


model = MyResNet()
model.to(DEVICE)