In [37]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from PIL import Image

import matplotlib.pyplot as plt

In [38]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda', index=0)

Загрузим дата сет:

In [39]:
dataset_folder = './data_set/'
train_folder = 'training_set/training_set'
test_folder = 'test_set/test_set'

In [40]:
train_set = torchvision.datasets.ImageFolder(dataset_folder+train_folder)
train_set.classes

['cats', 'dogs']

Проверим дата сет

In [41]:
def get_img(dataset:torchvision.datasets.ImageFolder, index):
    img = Image.open(dataset.imgs[index][0])
    lable = dataset.imgs[index][1]
    return (img, lable)

In [42]:
type(get_img(train_set, 10)[0])

PIL.JpegImagePlugin.JpegImageFile

Напишем класс датасет:

In [43]:
class MyImage(torch.utils.data.Dataset):
    def __init__(self, dataset, transform = None):
        super(MyImage, self).__init__()
        def chekRGB(dataset):
            data = []
            line = len(dataset)
            for index in range(line):
                img = get_img(dataset, index)
                if img[0].getbands() == ("R", "G", "B"):
                    data.append(img)
            return(data)

        self.data = chekRGB(dataset)
        self.transform = transform

    def __getitem__(self, index):
        img, lable = self.data[index]
        if self.transform != None:
            img = self.transform(img)
        return (img, lable)

    def __len__(self):
        return len(self.data)

Создадим трансформеры изибражения:

In [44]:
transform = transforms.Compose([
    transforms.Resize((224, 224)), transforms.ToTensor(), transforms.Normalize([0.5]*3, [0.5]*3)
])

In [45]:
train_dataset = MyImage(train_set, transform)

In [46]:
train_dataset[4001][0]

tensor([[[ 0.1294,  0.0667,  0.0431,  ...,  0.0667,  0.0667,  0.0588],
         [ 0.1451,  0.0980,  0.1059,  ...,  0.0667,  0.0667,  0.0588],
         [ 0.0824,  0.0667,  0.0824,  ...,  0.0431,  0.0431,  0.0431],
         ...,
         [ 0.4275,  0.4667,  0.5059,  ..., -0.2235, -0.2392, -0.2471],
         [ 0.3961,  0.4275,  0.4745,  ..., -0.2157, -0.2235, -0.2235],
         [ 0.3569,  0.3882,  0.4353,  ..., -0.2157, -0.2078, -0.2078]],

        [[-0.0588, -0.1294, -0.1843,  ..., -0.0510, -0.0588, -0.0667],
         [-0.0667, -0.1137, -0.1373,  ..., -0.0510, -0.0588, -0.0667],
         [-0.1373, -0.1686, -0.1529,  ..., -0.0745, -0.0824, -0.0824],
         ...,
         [ 0.5843,  0.6235,  0.6627,  ..., -0.2706, -0.2863, -0.2941],
         [ 0.5529,  0.5843,  0.6314,  ..., -0.2627, -0.2706, -0.2706],
         [ 0.5137,  0.5451,  0.5922,  ..., -0.2627, -0.2549, -0.2549]],

        [[-0.4824, -0.5686, -0.6235,  ..., -0.2314, -0.2392, -0.2471],
         [-0.4745, -0.5373, -0.5608,  ..., -0

НАпишем dataLoader

In [47]:
batch = 20

In [48]:
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch, shuffle=True)

Зададим модель:

In [49]:
model = torchvision.models.resnet50(pretrained=True)


ftrs = model.fc.in_features
model.fc = nn.Linear(ftrs, 2)

model.to(device)



ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 

Тренировка:

In [50]:
from tqdm import tqdm

In [51]:
def train(epoch, model, optimizer, criterion, dataloader, save = False):
    model.train()
    # loop over the dataset multiple times
    for epoch in range(epoch):
        running_loss = 0.0
        for datainp in (pdbar := tqdm(dataloader)):
            inputs, labels = datainp
            inputs = inputs.to(device)
            labels = labels.to(device)

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = model(inputs)
           
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            
            running_loss += loss.item()
            pdbar.set_description(f'epoch: {epoch}\tloss: {running_loss:.3F}')
        if save:
            torch.save(model, f'my_model_epoch_{epoch}.pt')
    print('Finished Training')
    model.eval()

In [52]:
train_model  = False

epoch = 10
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
criterion = nn.CrossEntropyLoss()

In [53]:
if train_model:
    train(epoch, model, optimizer, criterion, train_dataloader)

загрузка модели:

In [54]:
load_save = True

In [55]:
if load_save:
    model = torch.load('my_model_epoch.pt') 

Сохранение:

In [56]:
torch.save(model, f'my_model_epoch.pt')

Тестировка:

In [57]:
def test(model, dataloader):
    acc = 0.
    loss = 0.
    with torch.no_grad():
        for datainp in (pdbar := tqdm(dataloader)):
            inputs, labels = datainp
            inputs = inputs.to(device)
            labels = labels.to(device)

            output = model(inputs)
            _, out = torch.max(output, 1)
            if out == labels:
                acc+=1
            else:
                loss+=1
            pdbar.set_description(f'acc: {acc/(loss+acc)*100:.3F}%\tloss: {loss}')

In [58]:
test_set = torchvision.datasets.ImageFolder(dataset_folder+test_folder)

test_dataset = MyImage(test_set, transform)
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=True)

OSError: [Errno 24] Too many open files: './data_set/test_set/test_set\\cats\\cat.4183.jpg'

In [59]:
test(model, test_dataloader)

acc: 65.546%	loss: 697.0: 100%|██████████| 2023/2023 [00:37<00:00, 53.76it/s]
