In [104]:
!pip install pygame

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as utils
import torch.optim as optim
import torchvision.transforms.v2 as v2

import pygame as pg

import os
import json
from random import randint

from PIL import Image



In [106]:
# Split configuration: how many images to generate and which sub-directory of
# `dataset_dir` holds them (together with the `coords_file` JSON of targets).
train_data = {"qty": 1000, "dir": "train"}
test_data = {"qty": 400, "dir": "test"}

dataset_dir = "Sun_dataset"
coords_file = "coords.json"

# makedirs creates `dataset_dir` and each split directory as needed. exist_ok=True makes
# the cell safe to re-run and, unlike nesting the split checks under "dataset_dir does not
# exist", still creates a missing split directory when `dataset_dir` is already present.
for split in (train_data, test_data):
    os.makedirs(os.path.join(dataset_dir, split["dir"]), exist_ok=True)


In [108]:
# Build the synthetic dataset: paste the sun sprite at a random position on a random
# background, sprinkle random noise dots, save the PNG, and record the sprite centre
# in a per-split JSON file (image name -> (x, y)).
sun = pg.image.load("imgs/sun.png")
# Pygame is imported as `pg`; every file in imgs/ whose name starts with "img" is a background.
backgrounds = [pg.image.load(f"imgs/{img}") for img in os.listdir("imgs") if img.startswith("img")]
if not backgrounds:
    raise FileNotFoundError("no background images (imgs/img*) found")

backgrounds_len = len(backgrounds)
# Half the sprite size: the sprite is blitted so that its centre lands on (x, y).
half_w, half_h = sun.get_width() // 2, sun.get_height() // 2

for data in (train_data, test_data):
    sun_coords = {}

    # 1..qty inclusive, so exactly data["qty"] images are produced.
    for i in range(1, data["qty"] + 1):
        img_name = f"sun_gen_{i}.png"
        # randint is inclusive on both ends; this picks each background uniformly.
        random_background = backgrounds[randint(0, backgrounds_len - 1)].copy()
        width, height = random_background.get_size()

        # Noise: 20-100 small random-coloured dots anywhere on the background.
        for _ in range(randint(20, 100)):
            pg.draw.circle(random_background,
                           color=tuple(randint(0, 255) for _ in range(3)),
                           center=(randint(0, width - 1), randint(0, height - 1)),
                           radius=1)

        # Sun centre, constrained so the whole sprite stays inside the background.
        x = randint(half_w, width - half_w)
        y = randint(half_h, height - half_h)
        sun_coords[img_name] = (x, y)
        random_background.blit(sun, (x - half_w, y - half_h))

        pg.image.save(random_background, os.path.join(dataset_dir, data["dir"], img_name))

    with open(os.path.join(dataset_dir, data["dir"], coords_file), "w") as coordinates_file:
        json.dump(sun_coords, coordinates_file)

In [178]:
class SunDataset(utils.Dataset):
    """Generated sun images paired with the sun's (x, y) pixel coordinates.

    The split directory ``<path>/train`` or ``<path>/test`` must contain the image files
    and a JSON file mapping each image file name to its ``[x, y]`` target.

    Args:
        path: root dataset directory (the one containing ``train/`` and ``test/``).
        train: load the ``train`` split when True, otherwise ``test``.
        tfs: optional callable applied to each RGB PIL image (e.g. a torchvision transform).
        coords_name: file name of the coordinates JSON inside the split directory; should
            match the name the generation step wrote.
    """

    def __init__(self, path, train=True, tfs=None, coords_name="coords.json"):
        # Use the caller-supplied root rather than a notebook-level global.
        self.path = os.path.join(path, "train" if train else "test")
        self.tfs = tfs

        with open(os.path.join(self.path, coords_name), "r") as cf:
            self.coords = json.load(cf)

        self.length = len(self.coords)
        # Parallel tuples: imgs[k] is a file name, targets[k] its [x, y] list.
        self.imgs = tuple(self.coords.keys())
        self.targets = tuple(self.coords.values())

    def __getitem__(self, item):
        """Return ``(image, target)`` where target is a float32 tensor built from ``[x, y]``."""
        img_path = os.path.join(self.path, self.imgs[item])
        # The context manager releases the file handle even if decoding fails.
        with Image.open(img_path) as img:
            image = img.convert("RGB")

        # Use the transform stored on this instance, not a same-named global.
        if self.tfs is not None:
            image = self.tfs(image)

        return image, torch.tensor(self.targets[item], dtype=torch.float32)

    def __len__(self):
        return self.length

# PIL image -> tensor, then cast to float32 with scale=True.
tfs = v2.Compose([
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),
])

dataset_unit = SunDataset(dataset_dir, train=True, tfs=tfs)

# NOTE(review): this rebinds `train_data`, which the setup cell defined as the
# {"qty", "dir"} split dict; re-running the generation cell after this one fails.
# Consider a distinct name such as `train_loader` (and update the training cell).
train_data = utils.DataLoader(dataset_unit, batch_size=32, shuffle=True)

In [180]:
# CNN regressor: three conv -> ReLU -> max-pool stages followed by an MLP head that
# outputs two numbers, the predicted (x, y) of the sun.
model = nn.Sequential(
    # Stage 1: 3 -> 32 channels, spatial size halved.
    nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, padding="same"),
    nn.ReLU(),
    nn.MaxPool2d(kernel_size=2),
    # Stage 2: 32 -> 8 channels, spatial size halved again.
    nn.Conv2d(in_channels=32, out_channels=8, kernel_size=3, padding="same"),
    nn.ReLU(),
    nn.MaxPool2d(kernel_size=2),
    # Stage 3: 8 -> 4 channels, third halving.
    nn.Conv2d(in_channels=8, out_channels=4, kernel_size=3, padding="same"),
    nn.ReLU(),
    nn.MaxPool2d(kernel_size=2),
    nn.Flatten(),
    # 4 channels * 32 * 32 = 4096 features, i.e. the head assumes 256x256 inputs (256 / 2**3 = 32).
    nn.Linear(in_features=4 * 32 * 32, out_features=128),
    nn.ReLU(),
    nn.Linear(in_features=128, out_features=2),
)

# Adam optimizer with weight decay.
optimizer = optim.Adam(params=model.parameters(), lr=1e-3, weight_decay=1e-3)
# Mean squared error: regression loss between predicted and true coordinates.
loss_func = nn.MSELoss()

In [182]:
from tqdm import tqdm

epochs = 8
model.train()  # put the model in training mode once, before all epochs

for epoch in range(epochs):
    progress = tqdm(train_data, leave=True)
    loss_mean = 0

    for step, (batch_x, batch_y) in enumerate(progress, start=1):
        # Forward pass and regression loss for the current batch.
        loss = loss_func(model(batch_x), batch_y)

        # Backprop step: clear stale gradients, compute new ones, update weights.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Running mean of this epoch's batch losses: mean_n = l/n + (1 - 1/n) * mean_{n-1}.
        weight = 1 / step
        loss_mean = weight * loss.item() + (1 - weight) * loss_mean
        progress.set_description(f"Эпоха {epoch+1} из {epochs}, значения потерь - {loss_mean: .2f}")


# Optional checkpoint (disabled):
# torch.save(model.state_dict(), 'model_state_1.tar')

Эпоха 1 из 8, значения потерь -  11071.83: 100%|██████████| 32/32 [00:34<00:00,  1.08s/it]
Эпоха 2 из 8, значения потерь -  3548.49: 100%|██████████| 32/32 [00:30<00:00,  1.06it/s]
Эпоха 3 из 8, значения потерь -  3374.26: 100%|██████████| 32/32 [00:27<00:00,  1.15it/s]
Эпоха 4 из 8, значения потерь -  3241.37: 100%|██████████| 32/32 [00:26<00:00,  1.19it/s]
Эпоха 5 из 8, значения потерь -  3300.07: 100%|██████████| 32/32 [00:26<00:00,  1.23it/s]
Эпоха 6 из 8, значения потерь -  3538.10: 100%|██████████| 32/32 [00:25<00:00,  1.27it/s]
Эпоха 7 из 8, значения потерь -  3396.40: 100%|██████████| 32/32 [00:24<00:00,  1.28it/s]
Эпоха 8 из 8, значения потерь -  3092.12: 100%|██████████| 32/32 [00:24<00:00,  1.28it/s]


In [118]:
# Evaluate the trained model on the test split; `result` is the mean MSE over all test samples.
test_dataset = SunDataset(dataset_dir, train=False, tfs=tfs)
# No drop_last: with batch_size=96 it silently skipped the final partial batch (and left
# zero batches, hence a division by zero, for test sets smaller than one batch).
# A separate name keeps the `test_data` split-config dict from the setup cell intact.
test_loader = utils.DataLoader(test_dataset, batch_size=96, shuffle=False)

model.eval()

total_loss = 0.0
n_samples = 0
with torch.no_grad():
    for x_test, y_test in tqdm(test_loader, leave=False):
        predict = model(x_test)
        batch_size = len(y_test)
        # MSELoss averages within the batch; weighting by batch size makes the final value
        # the mean over all samples even when the last batch is smaller.
        total_loss += loss_func(predict, y_test).item() * batch_size
        n_samples += batch_size

if n_samples == 0:
    raise ValueError("test set is empty; nothing to evaluate")

result = total_loss / n_samples
result
        
    


                                             

453.0337829589844