In [30]:
import datetime
import os
import random

import torch


# Fix the RNG seeds so that results do not change between runs when nothing
# else has changed.
SEED = 666
random.seed(SEED)
torch.manual_seed(SEED)

In [31]:
from support_module import ImageToNumDataset, NoMaskModel

In [32]:
import torch
# Pick the compute device in order of preference: CUDA, then Apple MPS,
# otherwise fall back to the CPU.
if torch.cuda.is_available():
    device_name = "cuda"
elif torch.backends.mps.is_available():
    device_name = "mps:0"
else:
    device_name = "cpu"
DEVICE = torch.device(device_name)
DEVICE

device(type='mps', index=0)

In [33]:
import torch
from torchvision.transforms.v2 import ToDtype, Normalize, Compose, PILToTensor

# Preprocessing pipeline applied to every image: convert the PIL image to a
# tensor, cast to float32 with value scaling, then normalise with mean 0.5 and
# std 0.5.
preprocessing_steps = [
    PILToTensor(),
    ToDtype(torch.float32, scale=True),
    Normalize(mean=(0.5,), std=(0.5,)),
]
transform = Compose(preprocessing_steps)

In [34]:
# Labelled training data: images from data/train_images paired with the
# answers in data/train_answers.csv, each image passed through `transform`.
dataset = ImageToNumDataset(
    "data/train_images",
    answers_file="data/train_answers.csv",
    transform=transform,
)

In [35]:
from torch.utils.data import DataLoader, random_split

# Samples per batch, shared by the training and validation loaders.
BATCH_SIZE = 2**5

# Split 80% / 20% with a dedicated, seeded generator. Relying on the global RNG
# would make the partition depend on how many random numbers were drawn before
# this cell ran, so re-running the cell (or resuming from a checkpoint) could
# move validation samples into the training set.
split_generator = torch.Generator().manual_seed(666)
train_dataset, validation_dataset = random_split(dataset, (0.8, 0.2), generator=split_generator)

# Training batches are reshuffled every epoch; validation order is kept fixed.
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
validation_dataloader = DataLoader(validation_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [36]:
# Build the network and move it to the selected device.
model = NoMaskModel().to(DEVICE)
# Uncomment to continue training from a previously saved checkpoint:
# model.load_state_dict(torch.load("models/model.pt"))

In [37]:
from torch import nn, optim

# Cross-entropy loss optimised with Adam at a fixed learning rate.
LEARNING_RATE = 1e-5
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

In [38]:
from ignite.metrics import Accuracy, Loss
from ignite.engine import create_supervised_trainer, create_supervised_evaluator

# Engine that runs the optimisation steps on DEVICE.
trainer = create_supervised_trainer(model, optimizer, criterion, device=DEVICE)

# Engine that evaluates the same model; it reports accuracy and the criterion
# value under the keys 'accuracy' and 'nll'.
evaluation_metrics = {
    'accuracy': Accuracy(),
    'nll': Loss(criterion),
}
evaluator = create_supervised_evaluator(model, metrics=evaluation_metrics, device=DEVICE)

In [39]:
# Per-epoch loss and metric history, collected for plotting later on.
train_loss_values, validation_loss_values, validation_accuracy_values = [], [], []

In [40]:
import logging


# Configure the root logger at INFO level and keep a handle to it.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger()

# Raise ignite's Engine logger to WARNING so only the messages logged by this
# notebook's own handlers show up at INFO level.
ignite_engine_logger = logging.getLogger("ignite.engine.engine.Engine")
ignite_engine_logger.setLevel(logging.WARNING)

In [41]:
# Number of training epochs; passed to trainer.run as max_epochs.
EPOCHS = 50

In [42]:
from ignite.engine import Events


@trainer.on(Events.EPOCH_STARTED)
def log_training_start(engine):
    """Log the epoch number and the current wall-clock time when an epoch starts.

    Args:
        engine: the engine that fired the event; its ``state.epoch`` is logged.
    """
    start_message = f"Starting learning at epoch {engine.state.epoch} in {datetime.datetime.now()}"
    logging.info(start_message)


# @trainer.on(Events.EPOCH_COMPLETED)
def log_training_results(engine):
    """Evaluate the model on the training set and record/log the results.

    The registration decorator above is commented out, so this handler is not
    attached to the trainer and ``train_loss_values`` is not filled by it
    unless the decorator is restored.

    Args:
        engine: the engine that fired the event; its ``state.epoch`` is logged.
    """
    end_message = f"End learning at epoch {engine.state.epoch} in {datetime.datetime.now()}"
    logging.info(end_message)

    # Run the evaluator over the training data and read back its metrics.
    evaluator.run(train_dataloader)
    metrics = evaluator.state.metrics

    # Keep the loss for plotting, then log accuracy and loss for this epoch.
    train_loss_values.append(metrics['nll'])
    summary = (
        f"Training Results - Epoch: {engine.state.epoch}  "
        f"Avg accuracy: {metrics['accuracy']:.2f} "
        f"Avg loss: {metrics['nll']:.2f}"
    )
    logging.info(summary)
    
@trainer.on(Events.EPOCH_COMPLETED)
def log_validation_results(engine):
    """Validate the model after each epoch, record the metrics and checkpoint.

    Runs the evaluator on the validation loader, appends the loss and accuracy
    to ``validation_loss_values`` / ``validation_accuracy_values``, logs them,
    and overwrites the checkpoint at ``models/model.pt`` with the current
    weights.

    Args:
        engine: the engine that fired the event; its ``state.epoch`` is logged.
    """
    logging.info(f"End learning at epoch {engine.state.epoch} in {datetime.datetime.now()}")
    logging.info(f"Starting validation on epoch {engine.state.epoch}")
    # Evaluate the model on the validation set.
    evaluator.run(validation_dataloader)
    metrics = evaluator.state.metrics
    # Record and log the average accuracy and loss on the validation set.
    validation_loss_values.append(metrics['nll'])
    validation_accuracy_values.append(metrics['accuracy'])
    logging.info(
        f"Validation Results - Epoch: {engine.state.epoch}  "
        f"Avg accuracy: {metrics['accuracy']:.3f} "
        f"Avg loss: {metrics['nll']:.3f}"
    )
    logging.info(f"End of validation on epoch {engine.state.epoch}")
    # Create the checkpoint directory first: otherwise torch.save raises only
    # after a whole epoch has been trained, and that work is lost.
    checkpoint_path = "models/model.pt"
    os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
    torch.save(model.state_dict(), checkpoint_path)

In [None]:
# Start training for EPOCHS epochs; the handlers registered on `trainer` log
# progress, run validation and save a checkpoint after every epoch.
trainer.run(train_dataloader, max_epochs=EPOCHS)

INFO:root:Starting learning at epoch 1 in 2023-12-09 21:27:46.954689
INFO:root:End learning at epoch 1 in 2023-12-09 21:28:51.970737
INFO:root:Starting validation on epoch 1
INFO:root:Validation Results - Epoch: 1  Avg accuracy: 0.747 Avg loss: 0.810
INFO:root:End of validation on epoch 1
INFO:root:Starting learning at epoch 2 in 2023-12-09 21:29:03.702151
INFO:root:End learning at epoch 2 in 2023-12-09 21:30:08.848099
INFO:root:Starting validation on epoch 2
INFO:root:Validation Results - Epoch: 2  Avg accuracy: 0.783 Avg loss: 0.772
INFO:root:End of validation on epoch 2
INFO:root:Starting learning at epoch 3 in 2023-12-09 21:30:20.543127
INFO:root:End learning at epoch 3 in 2023-12-09 21:36:30.540234
INFO:root:Starting validation on epoch 3
INFO:root:Validation Results - Epoch: 3  Avg accuracy: 0.800 Avg loss: 0.756
INFO:root:End of validation on epoch 3
INFO:root:Starting learning at epoch 4 in 2023-12-09 21:42:54.743468
INFO:root:End learning at epoch 4 in 2023-12-09 23:00:10.8504

In [None]:
from matplotlib import pyplot as plt

# Training curves: loss on the left, validation accuracy on the right.
# Epochs are numbered from 1, matching the numbers in the training log.
fig, (loss_ax, accuracy_ax) = plt.subplots(1, 2, figsize=(10, 5))

# The training-loss history is only filled when the training-set evaluation
# handler is registered; skip the curve instead of drawing an empty series.
if train_loss_values:
    loss_ax.plot(range(1, len(train_loss_values) + 1), train_loss_values, label='Training Loss')
loss_ax.plot(range(1, len(validation_loss_values) + 1), validation_loss_values, label='Validation Loss')
loss_ax.set_title('Loss per epoch')
loss_ax.set_xlabel('Epochs')
loss_ax.set_ylabel('Loss')
loss_ax.legend()

accuracy_ax.plot(
    range(1, len(validation_accuracy_values) + 1),
    validation_accuracy_values,
    label='Validation Accuracy',
    color='red',
)
accuracy_ax.set_title('Validation accuracy per epoch')
accuracy_ax.set_xlabel('Epochs')
accuracy_ax.set_ylabel('Accuracy')
accuracy_ax.legend()

fig.tight_layout()
plt.show()

In [None]:
import torch
import random


# Fix the RNG seeds so that results do not change between runs when nothing
# else has changed.
torch.manual_seed(666)
random.seed(666)

# NOTE(review): unlike the training setup, MPS is not selected here, so
# inference runs on CUDA when available and on the CPU otherwise.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

test_model = NoMaskModel()
test_model = test_model.to(DEVICE)
# map_location remaps the saved tensors onto DEVICE, so a checkpoint written
# on another accelerator (e.g. MPS) still loads on a machine without it.
# weights_only restricts unpickling to tensors and plain containers.
test_model.load_state_dict(torch.load("models/model.pt", map_location=DEVICE, weights_only=True))
test_dataset = ImageToNumDataset("data/test_images", transform=transform)

In [None]:
import csv
from IPython.display import clear_output

# Predict a class for every test image and write the results to answer.csv
# with the columns "id" (position in the dataset) and "target_feature".
test_model.eval()
len_dataset = len(test_dataset)
# newline="" lets the csv module control line endings (avoids blank rows on
# Windows); gradients are not needed for inference, so disable them once.
with open("answer.csv", "w", newline="") as file, torch.no_grad():
    writer = csv.writer(file, delimiter=",")
    writer.writerow(["id", "target_feature"])
    for index, image in enumerate(test_dataset):
        # Add a batch dimension and move the input to the model's device;
        # without the move a CUDA model fails with a device mismatch.
        pred_y = test_model(image.unsqueeze(0).to(DEVICE))
        # Index of the highest score (first one on ties).
        answer = int(pred_y[0].argmax())
        writer.writerow([index, answer])
        if index % 10 == 0 or index % 10 == 9:
            print(f"{(index / len_dataset) * 100:.2f}%")
print("100%")