<a href="https://colab.research.google.com/github/Vimp17/py/blob/main/CNN_REG_Sunposition.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import zipfile

# Location of the zip archive and the directory it is unpacked into
# (the archive may live elsewhere, e.g. '/content/drive/MyDrive/archive.zip').
archive_path = '/content/dataset_gen_reg.zip'
extract_folder = '/content/images'

# Unpack every member of the archive into the target directory.
with zipfile.ZipFile(archive_path, mode='r') as archive:
    archive.extractall(path=extract_folder)

print("Архив успешно распакован!")

In [None]:
import os
from random import randint
import json
import pygame

# Synthetic dataset generator: pastes a sun sprite at a random position onto a
# random background sprinkled with coloured noise dots, and records the sun
# centre of every generated image in a per-split JSON file.

# Per-split settings: how many images to generate and the output sub-directory.
train_data = {'total': 10000, 'dir': "train"}
test_data = {'total': 1000, 'dir': "test"}
total_bk = 10  # backgrounds are loaded from back_1.png .. back_<total_bk>.png
dir_out = 'dataset_reg'
file_format = 'format.json'
# RGB colours the noise dots are drawn with.
cls = [(255, 255, 255), (0, 0, 255), (0, 255, 0), (255, 0, 0)]
total_cls = len(cls)  # derived from the palette so the two cannot drift apart

# Create each split directory independently of whether dir_out already exists;
# makedirs also creates dir_out itself and is a no-op for existing directories.
for info in (train_data, test_data):
    os.makedirs(os.path.join(dir_out, info['dir']), exist_ok=True)

sun = pygame.image.load("/content/images/dataset_gen_reg/images/sun64.png")
backs = [pygame.image.load(f"/content/images/dataset_gen_reg/images/back_{n}.png") for n in range(1, total_bk+1)]

for info in (train_data, test_data):
    # Maps generated file name -> (x, y) centre of the pasted sun sprite.
    sun_coords = {}

    for i in range(1, info['total'] + 1):
        file_out = f"sun_reg_{i}.png"
        # Work on a copy so the loaded background surface stays clean.
        im = backs[randint(0, total_bk - 1)].copy()

        # Scatter 20..100 radius-1 dots of random palette colours.
        # NOTE(review): randint is inclusive, so a dot centre may be 256 —
        # presumably one past the last pixel of a 256px image; confirm.
        for _ in range(randint(20, 100)):
            x0 = randint(0, 256)
            y0 = randint(0, 256)
            pygame.draw.circle(im, cls[randint(0, total_cls - 1)], (x0, y0), 1)

        # (x, y) is the sun centre; the sprite is blitted with its top-left
        # corner 32px up/left of it (presumably a 64x64 sprite — see sun64.png).
        x = randint(32, 256 - 32)
        y = randint(32, 256 - 32)
        sun_coords[file_out] = (x, y)
        im.blit(sun, (x - 32, y - 32))

        pygame.image.save(im, os.path.join(dir_out, info['dir'], file_out))

    # The context manager guarantees the label file is flushed and closed
    # even if serialisation fails.
    with open(os.path.join(dir_out, info['dir'], file_format), "w") as fp:
        json.dump(sun_coords, fp)


In [None]:
import os
import json
from PIL import Image

import torch
import torch.utils.data as data
import torchvision.transforms.v2 as tfs
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm


class SunDataset(data.Dataset):
    """Image dataset whose samples and targets are listed in a ``format.json`` file.

    The JSON file lives in ``<path>/train`` or ``<path>/test`` and maps an
    image file name (relative to that directory) to its target value(s).
    """

    def __init__(self, path, train=True, transform=None):
        """Read the annotation file of the selected split.

        Args:
            path: Root directory that contains the ``train`` and ``test`` folders.
            train: Use the ``train`` folder when true, otherwise ``test``.
            transform: Optional callable applied to each loaded RGB image.
        """
        subset = "train" if train else "test"
        self.path = os.path.join(path, subset)
        self.transform = transform

        annotation_file = os.path.join(self.path, "format.json")
        with open(annotation_file, "r") as fp:
            self.format = json.load(fp)

        # Keep names and targets in two parallel tuples (same order as the JSON).
        self.files = tuple(self.format)
        self.targets = tuple(self.format[name] for name in self.files)
        self.length = len(self.files)

    def __getitem__(self, item):
        """Return ``(image, target)`` for index ``item``; target is a float32 tensor."""
        img = Image.open(os.path.join(self.path, self.files[item])).convert('RGB')
        target = torch.tensor(self.targets[item], dtype=torch.float32)

        if not self.transform:
            return img, target
        return self.transform(img), target

    def __len__(self):
        """Return the number of annotated samples."""
        return self.length

# Preprocessing applied to every image: ToImage followed by a conversion to
# float32 with value scaling enabled.
transforms = tfs.Compose([
    tfs.ToImage(),
    tfs.ToDtype(torch.float32, scale=True),
])
d_train = SunDataset("dataset_reg", transform=transforms)
train_data = data.DataLoader(d_train, batch_size=32, shuffle=True)

# Three conv -> ReLU -> 2x2 max-pool stages, then a two-layer fully connected
# head with 2 outputs. The 4096 input features of the first Linear layer
# correspond to 4 channels x 32 x 32 — assumes 256x256 input images (TODO confirm).
model = nn.Sequential(
    nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, padding='same'),
    nn.ReLU(),
    nn.MaxPool2d(kernel_size=2),
    nn.Conv2d(in_channels=32, out_channels=8, kernel_size=3, padding='same'),
    nn.ReLU(),
    nn.MaxPool2d(kernel_size=2),
    nn.Conv2d(in_channels=8, out_channels=4, kernel_size=3, padding='same'),
    nn.ReLU(),
    nn.MaxPool2d(kernel_size=2),
    nn.Flatten(),
    nn.Linear(in_features=4096, out_features=128),
    nn.ReLU(),
    nn.Linear(in_features=128, out_features=2),
)
optimizer = optim.Adam(params=model.parameters(), lr=0.001, weight_decay=0.001)
loss_func = nn.MSELoss()

# Training: `epochs` passes over the training loader, with the running mean of
# the batch losses shown in the progress bar.
epochs = 5
model.train()

for i in range(epochs):
    loss_mean = 0

    progress = tqdm(train_data, leave=True)
    for step, (batch_images, batch_targets) in enumerate(progress, start=1):
        optimizer.zero_grad()

        outputs = model(batch_images)
        loss = loss_func(outputs, batch_targets)
        loss.backward()
        optimizer.step()

        # Incremental mean: the newest batch loss enters with weight 1/step.
        weight = 1/step
        loss_mean = weight * loss.item() + (1 - weight) * loss_mean
        progress.set_description(f"Epoch [{i+1}/{epochs}], loss_mean={loss_mean:.3f}")

# Persist only the learned parameters (state_dict), not the module object.
torch.save(model.state_dict(), 'model_sun.tar')

# Evaluation: average of the per-batch losses over the test split.
d_test = SunDataset("dataset_reg", train=False, transform=transforms)
test_data = data.DataLoader(d_test, batch_size=50, shuffle=False)

Q = 0       # sum of batch losses, turned into their mean below
count = 0   # number of batches processed
model.eval()

# No gradients are needed for evaluation, so disable autograd for the whole loop.
with torch.no_grad():
    for batch_images, batch_targets in tqdm(test_data, leave=True):
        batch_loss = loss_func(model(batch_images), batch_targets)
        Q += batch_loss.item()
        count += 1

Q /= count

print(Q)



In [None]:
from PIL import Image
import json
import os

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torchvision.transforms.v2 as tfs

# model = nn.Sequential(
#     nn.Conv2d(3, 32, 3, padding='same'),
#     nn.ReLU(),
#     nn.MaxPool2d(2),
#     nn.Conv2d(32, 16, 3, padding='same'),
#     nn.ReLU(),
#     nn.MaxPool2d(2),
#     nn.Conv2d(16, 8, 3, padding='same'),
#     nn.ReLU(),
#     nn.MaxPool2d(2),
#     nn.Conv2d(8, 4, 3, padding='same'),
#     nn.ReLU(),
#     nn.MaxPool2d(2),
#     nn.Flatten(),
#     nn.Linear(1024, 256),
#     nn.ReLU(),
#     nn.Linear(256, 2)
# )

# Inference demo: rebuild the network, load the trained weights, predict the
# target for one test image and draw the prediction on top of that image.

# The layer layout must match the one used when 'model_sun.tar' was saved,
# otherwise load_state_dict() fails on mismatching keys/shapes.
model = nn.Sequential(
    nn.Conv2d(3, 32, 3, padding='same'),
    nn.ReLU(),
    nn.MaxPool2d(2),
    nn.Conv2d(32, 8, 3, padding='same'),
    nn.ReLU(),
    nn.MaxPool2d(2),
    nn.Conv2d(8, 4, 3, padding='same'),
    nn.ReLU(),
    nn.MaxPool2d(2),
    nn.Flatten(),
    nn.Linear(4096, 128),
    nn.ReLU(),
    nn.Linear(128, 2)
)

path = 'dataset_reg/test/'
num_img = 100
file_name = f'sun_reg_{num_img}.png'

# The checkpoint is a plain state_dict (tensors only), so the restricted
# unpickler is sufficient and avoids executing arbitrary pickled code.
st = torch.load('model_sun.tar', weights_only=True)
model.load_state_dict(st)

# Named `labels` rather than `format` so the builtin format() is not shadowed
# for the rest of the notebook session.
with open(os.path.join(path, "format.json"), "r") as fp:
    labels = json.load(fp)

transforms = tfs.Compose([tfs.ToImage(), tfs.ToDtype(torch.float32, scale=True)])
img = Image.open(os.path.join(path, file_name)).convert('RGB')
# Add a leading batch dimension: the model expects a batch of images.
img_t = transforms(img).unsqueeze(0)

model.eval()
# Pure inference: skip building the autograd graph.
with torch.no_grad():
    predict = model(img_t)
print(predict)
# Look the ground truth up by file name instead of by position in the JSON,
# so the result does not depend on the order of entries in the file.
print(labels[file_name])
p = predict.squeeze().numpy()

plt.imshow(img)
# Mark the predicted point (first output on the x axis, second on the y axis).
plt.scatter(p[0], p[1], s=20, c='r')
plt.show()