In [None]:
import io
import os

import torchvision
import torchvision.transforms as transforms
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor

def get_model_instance_segmentation(num_classes):
    """Build a Mask R-CNN (ResNet-50 FPN) whose heads predict ``num_classes`` classes.

    The model is created with ``pretrained=True`` and then both ROI heads,
    the box predictor and the mask predictor, are replaced by new predictor
    modules sized for ``num_classes``.

    Args:
        num_classes: Number of classes the replacement heads should output.
            NOTE(review): torchvision detection models conventionally count
            the background as one class; confirm the caller includes it.

    Returns:
        The torchvision Mask R-CNN model with its two predictors replaced.
    """
    # Start from the pretrained instance-segmentation model (COCO weights,
    # according to the original author's comment).
    # NOTE(review): `pretrained=` is deprecated in newer torchvision releases
    # in favour of `weights=`; migrate once the minimum version is settled.
    model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True)
    roi_heads = model.roi_heads

    # Box head: reuse the input width of the existing classifier layer and
    # swap in a new predictor.
    box_in_features = roi_heads.box_predictor.cls_score.in_features
    roi_heads.box_predictor = FastRCNNPredictor(box_in_features, num_classes)

    # Mask head: reuse the input channel count of the existing mask predictor
    # and swap in a new one with a 256-channel hidden layer.
    mask_in_channels = roi_heads.mask_predictor.conv5_mask.in_channels
    mask_hidden_channels = 256
    roi_heads.mask_predictor = MaskRCNNPredictor(
        mask_in_channels, mask_hidden_channels, num_classes
    )

    return model

In [None]:
# Augmenting pipeline: resize to 256x256, rotate by a random angle of up to
# 20 degrees, convert to a tensor, then normalise each channel.
# NOTE(review): the mean/std values look like the commonly quoted CIFAR-10
# statistics; confirm they are appropriate for this image set.
transformation = transforms.Compose([
    transforms.Resize(size=(256, 256)),
    transforms.RandomRotation(degrees=20),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.4914, 0.4822, 0.4465), std=(0.247, 0.243, 0.261)),
])

# Same steps without the random rotation.
base_transform = transforms.Compose([
    transforms.Resize(size=(256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.4914, 0.4822, 0.4465), std=(0.247, 0.243, 0.261)),
])

In [None]:
from torchvision.transforms import transforms
import psycopg2

# Connection settings can be overridden through environment variables so that
# real credentials never have to be stored in the notebook. The fallback
# values reproduce the previous hard-coded local settings, so behaviour is
# unchanged when no variable is set.
# NOTE(review): remove the password fallback once every environment provides
# HSE_MEDICAL_DB_PASSWORD.
conn_select = psycopg2.connect(
    database=os.environ.get("HSE_MEDICAL_DB_NAME", "hse_medical"),
    user=os.environ.get("HSE_MEDICAL_DB_USER", "hse_medical"),
    password=os.environ.get("HSE_MEDICAL_DB_PASSWORD", "123456"),
    host=os.environ.get("HSE_MEDICAL_DB_HOST", "127.0.0.1"),
    port=os.environ.get("HSE_MEDICAL_DB_PORT", "5450"),
    options="-c search_path=analyze_medical",
)

# With autocommit enabled each statement takes effect immediately; no explicit
# commit() is required after executing it.
conn_select.autocommit = True

# Bare tensor conversion with no resizing or normalisation.
transform = transforms.ToTensor()


def get_connection():
    """Return the shared module-level psycopg2 connection.

    Every call returns the same ``conn_select`` object; no new connection is
    opened.
    """
    return conn_select

In [None]:
from torch.utils.data import Dataset


class CustomDataset(Dataset):
    """Index-based dataset pairing ``data[i]`` with ``labels[i]``.

    Args:
        data: Indexable collection of samples; its length defines the dataset
            length.
        labels: Indexable collection of labels, looked up with the same index
            as ``data``.
        transform: Optional callable applied to each sample (not to the label)
            when it is fetched. Skipped when falsy.
    """

    def __init__(self, data, labels, transform=None):
        self.data, self.labels = data, labels
        self.transform = transform

    def __len__(self):
        """Return the number of samples, taken from ``data``."""
        return len(self.data)

    def __getitem__(self, index):
        """Return ``(sample, label)`` for ``index``, transforming the sample if configured."""
        sample = self.data[index]
        # The transform runs on every access, so random transforms yield a
        # different result each time the same index is requested.
        sample = self.transform(sample) if self.transform else sample
        return sample, self.labels[index]

In [None]:

# Fetch every training row as a (target, image) pair.
sql1 = '''select 
    target,
    image from medical_pictures_train;'''

# The context manager closes the cursor even when the query or fetch raises,
# so the cursor is no longer leaked on errors.
with conn_select.cursor() as cursor:
    cursor.execute(sql1)
    data_postgres = cursor.fetchall()
print(f"datatrain_size : {len(data_postgres)}")

# Split the rows into parallel lists: column 0 is the target, column 1 holds
# the encoded image bytes, which are opened as image objects.
# NOTE(review): `Image` is not imported in any cell visible in this review;
# presumably `from PIL import Image` is intended. Confirm and add it to the
# import cell.
targets = [row[0] for row in data_postgres]
images = [Image.open(io.BytesIO(row[1])) for row in data_postgres]

# Training dataset: the augmenting pipeline is applied lazily on item access.
data_train = CustomDataset(images, targets, transform=transformation)

In [None]:
import torch

# Training configuration.
num_classes = 23
num_epochs = 10

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = get_model_instance_segmentation(num_classes)
model.to(device)

optimizer = torch.optim.SGD(
    model.parameters(),
    lr=0.005,
    momentum=0.9,
    weight_decay=0.0005,
)

# NOTE(review): `get_training_data` is not defined in any cell visible in this
# review, and the `data_train` dataset built earlier is not used here. Confirm
# where the loader comes from.
data_loader = get_training_data()

for epoch in range(num_epochs):
    # Train for one epoch.
    model.train()
    i = 0  # batch counter, restarted at the start of every epoch
    # NOTE(review): this loop rebinds the module-level names `images` and
    # `targets`, which the data-loading cell also assigns.
    for images, targets in data_loader:
        # Move every image and every tensor in each target dict to the device.
        images = [image.to(device) for image in images]
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        # Called with targets, the model returns a dict of losses; their sum
        # is the quantity that gets optimised.
        loss_dict = model(images, targets)
        losses = sum(loss_dict.values())

        # Update the model weights.
        optimizer.zero_grad()
        losses.backward()
        optimizer.step()

        i += 1
        if i % 50 == 0:
            print(f"Iteration #{i} loss: {losses.item()}")