In [1]:
import os
import numpy as np
import pandas as pd
import sklearn
import seaborn as sns
import geopandas as gpd
import rasterio
import torch
import torch.optim as optim
import torch.nn as nn
from tqdm import tqdm
from osgeo import gdal
from osgeo import ogr  
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from tools.deeplab import DeepLabV3PlusMobileNetV3
from tools.unet import UNet
from tools.utils import merge_multiband_windowss, metrics
from tools.object_extractor import multigeojson_to_multichannel_mask, update_metadata_only
from tools.utils import plot_loss, optimizer_to
from tools.segformerB0 import SegFormerB0_12Channel
from tools.losses import FocalDiceLossBinary, AsymmetricLossOptimized
from torch.optim.swa_utils import get_ema_multi_avg_fn

In [2]:
# Dataset root and the objects spreadsheet. Built with os.path.join so the
# notebook also runs on non-Windows hosts; on Windows the resulting strings
# are identical to the previous hard-coded ".\\..." literals.
main_path = os.path.join(".", "dataset", "train")
obj_path = os.path.join(".", "objects1.xlsx")

# Генерация или чтение объектов

In [None]:
from tools.data_extractor import get_json
from tools.traintest_splitter import create_train_test

# Switch: when True the train/test datasets are (re)built from `main_path`.
generate_new = True

# Both names exist even when generation is disabled, so later cells can
# reference them; in that case they stay None.
train_dataset = None
test_dataset = None

if generate_new:
    # create_train_test returns six values; only the first two are used here.
    train_dataset, test_dataset, _, _, _, _ = create_train_test(main_path, seed=83)


In [4]:
from tools.data_extractor import get_json
from tools.test_splitter import create_test

In [5]:
#model = DeepLabV3PlusMobileNetV3(num_classes=1, pretrained_backbone_path='./mobilenet_v3_small-047dcff4.pth')
#CLASS_WEIGHTS = [4.0, 1.0, 2.0, 2.0,1.0, 2.0, 0.5,5.0,1.0,0.5]

In [11]:
# Segmentation network used by the train_model cells below; a single output
# channel (num_classes=1), i.e. one binary mask.
model = SegFormerB0_12Channel(num_classes=1)

In [13]:
from PIL import Image
import numpy as np

# Quick sanity check of the value range of one LiDAR raster.
# The file is opened through a context manager so the handle is released
# as soon as the pixels have been copied into a NumPy array.
with Image.open('./dataset/train/002_ДЕМИДОВКА_FINAL/02_Демидовка_Li_карты/01_Демидовка_Lidar_c.tif') as img:
    pixels = np.array(img)
max_value = pixels.max()
max_value



np.uint8(252)

## Выделение окон

In [None]:
from shapely.geometry import box, mapping, Polygon, MultiPolygon
import json

# Ordered keyword table used to derive the object class from a vector-file
# name. Order matters: the first matching entry wins.
CLASS_KEYWORDS = (
    (("курганы",), "курганы", 0),
    (("дороги",), "дороги", 1),
    (("фортификация", "фортификации"), "фортификации", 2),
    (("архитектура",), "архитектура", 3),
    (("ямы",), "ямы", 4),
    (("городище", "городища"), "городище", 5),
    (("иное",), "иное", 6),
    (("селище",), "селище", 7),
    (("пашня", "пахота"), "пашня", 8),
    (("межа",), "межа", 9),
)


def _classify_vector_file(path):
    """Return ``(class_name, class_int)`` for a vector file, or None if its name matches no class."""
    fname_lower = os.path.basename(path).lower()
    for keywords, class_name, class_int in CLASS_KEYWORDS:
        if any(keyword in fname_lower for keyword in keywords):
            return class_name, class_int
    return None


def _load_vector_layer(path, src):
    """Read one vector file and express it in the CRS of the open raster ``src``.

    Returns None for an empty file. When the file declares a CRS different
    from the raster's, the declared CRS is first assumed to be wrong: the
    coordinates are relabelled with the raster CRS and kept if they then
    overlap the raster extent; otherwise the layer is reprojected.
    """
    gdf = gpd.read_file(path)
    if gdf.empty:
        return None
    if gdf.crs is None:
        return gdf.set_crs(src.crs)
    if gdf.crs != src.crs:
        relabelled = gdf.set_crs(src.crs, allow_override=True)
        if box(*src.bounds).intersects(box(*relabelled.total_bounds)):
            return relabelled
        return gdf.to_crs(src.crs)
    return gdf


dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=1, shuffle=True)
tile_size = 2048
windows = []
shapes = []
photo_counter = 0
for images in dataloader:
    print(f"Количество обработанных изображений {photo_counter}")
    photo_counter += 1
    for res in images["batches"]:
        update_metadata_only(res[0][0], res[0][0], int(images["UTM"][0]))
        with rasterio.open(res[0][0]) as src:
            # Load every annotation layer once per raster instead of once per tile.
            layers = []
            for path_raw in images["vector_mask"]:
                path = path_raw[0]
                class_info = _classify_vector_file(path)
                if class_info is None:
                    continue
                try:
                    gdf = _load_vector_layer(path, src)
                except Exception as e:
                    # Best effort: a broken file must not stop the whole extraction.
                    print(f"Ошибка при обработке {path}: {e}")
                    continue
                if gdf is not None:
                    layers.append((path, class_info, gdf))

            # Tiles start strictly inside the raster; the last row/column is
            # clipped to the raster size through x_end / y_end.
            for y in tqdm(range(0, src.height, tile_size)):
                for x in range(0, src.width, tile_size):
                    x_end = min(x + tile_size, int(src.width))
                    y_end = min(y + tile_size, int(src.height))

                    # World-coordinate box of the (clipped) tile. offset="ul"
                    # yields pixel corners rather than pixel centres, so the
                    # box lines up with the pixel grid.
                    left, top = rasterio.transform.xy(src.transform, y, x, offset="ul")
                    right, bottom = rasterio.transform.xy(src.transform, y_end, x_end, offset="ul")
                    window_bbox = box(left, bottom, right, top)

                    for path, (cur_class, cur_class_int), gdf in layers:
                        try:
                            hits = gdf[gdf.intersects(window_bbox)]
                            if hits.empty:
                                continue
                            geoms = []
                            for geom in hits.geometry:
                                if geom.is_empty or not geom.is_valid:
                                    continue
                                # MultiPolygons are stored as their individual parts.
                                parts = geom.geoms if geom.geom_type == "MultiPolygon" else (geom,)
                                geoms.extend(mapping(part) for part in parts)
                            if geoms:
                                # The sample's own fields (paths, UTM zone, ...) are
                                # stored next to the geometry so the training cell can
                                # work from the JSON alone.
                                # NOTE(review): json.dump below requires every value
                                # in `images` to be JSON-serialisable - confirm.
                                shapes.append({
                                    "geom": geoms,
                                    "pixel_coords": [x, y, x_end, y_end],
                                    "class_name": cur_class,
                                    "class_int": cur_class_int,
                                    **images,
                                })
                        except Exception as e:
                            print(f"Ошибка при обработке {path}: {e}")
                            continue
with open('thedata.json', 'w') as fp:
    json.dump(shapes, fp)

# Обучение

In [None]:
from rasterio.features import rasterize
import random
import json
from affine import Affine
import geopandas
from shapely.geometry import Polygon
from tools.utils import get_windowss, metrics_ae
import torch.nn.functional as F

# Side length, in pixels, of the square windows used for training.
tile_size=2048

# Notebook-level optimizer and LR scheduler; train_model below reads both as
# globals. The second positional argument of AdamW is the learning rate.
optimizer = optim.AdamW(model.parameters(), 1e-4)
# One-cycle schedule with cosine annealing: 10% of the steps (pct_start) are
# warm-up towards max_lr.
# NOTE(review): the schedule is configured for exactly 1000 steps.
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=1e-4, 
    total_steps=1000,
    pct_start=0.1,
    anneal_strategy='cos'
)

def train_model(model, photo_counter, train_from_scratch=False):
    """Train the binary segmentation model on the windows stored in ``thedata.json``.

    Only records with ``class_int == 0`` are used. For each of them the image
    window is read with ``merge_multiband_windowss``, the stored GeoJSON
    geometries are rasterized into a one-channel mask, image and mask are
    rotated together by a random multiple of 90 degrees and collected into
    batches of four. Every batch is downscaled by a factor of four and used
    for one optimisation step. After every step the latest checkpoint is
    written; a second checkpoint is written whenever the batch IoU improves.

    The function uses the notebook-level ``optimizer`` and ``scheduler``.

    NOTE(review): a later cell defines another ``train_model``; once that
    cell has run, it shadows this definition.

    Args:
        model: network to train; moved to the selected device in place.
        photo_counter: kept for interface compatibility; not used here.
        train_from_scratch: despite the name, ``True`` RESUMES from the
            checkpoint ``Segformer_b0_2048_1_.pth`` (model, optimizer,
            scheduler and step counter); ``False`` trains the model as passed.

    Returns:
        The same ``model`` object after training.
    """
    batch_size = 4
    passes = 20
    random.seed(a=838383)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    criterion = FocalDiceLossBinary()

    epochal_counter = 0
    max_iou = 0
    if train_from_scratch:
        # A plain relative path (works on every OS); map_location lets a
        # checkpoint written on GPU be resumed on a CPU-only machine.
        checkpoint = torch.load("Segformer_b0_2048_1_.pth", map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        optimizer_to(optimizer, device)
        epochal_counter = checkpoint['epoch']
    model.to(device)

    def save_checkpoint(path, epoch, loss_value, iou):
        """Write the full training state to ``path``."""
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            # Stored as a plain float so the file does not carry a tensor
            # that is attached to the autograd graph.
            'loss': loss_value,
            'iou': iou,
            'scheduler_state_dict': scheduler.state_dict(),
        }, path)

    with open("./thedata.json", "r") as js:
        json_data = json.load(js)
    # Only class 0 is trained here.
    selected_windows = [obj for obj in json_data if obj["class_int"] == 0]

    # OneCycleLR raises once it is stepped past its configured number of
    # steps; after that point the last learning rate is simply kept.
    scheduler_limit = getattr(scheduler, "total_steps", None)

    image_tensors = []
    mask_tensors = []
    train_loss = []
    for _ in range(passes):
        model.train()
        for obj in tqdm(selected_windows):
            x_start, y_start, x_end, y_end = obj['pixel_coords']
            w, h = x_end - x_start, y_end - y_start
            if w <= 0 or h <= 0:
                # Degenerate window: nothing to rasterize.
                continue
            for res in obj["batches"]:
                update_metadata_only(res[0][0], res[0][0], int(obj["UTM"][0]))
                with rasterio.open(res[0][0]) as src:
                    # Affine transform of the window: raster transform shifted
                    # to the window's top-left pixel.
                    window_transform = src.transform * Affine.translation(x_start, y_start)
                image = merge_multiband_windowss(res, x_start, y_start, tile_size)

                # The mask always has the full tile size; the part outside
                # the (possibly clipped) window stays zero.
                multimask = np.zeros((1, tile_size, tile_size), dtype=int)
                # The stored GeoJSON-like dicts are rasterized directly, so
                # polygon holes are honoured and no geometry is rebuilt by hand.
                multimask[:, :h, :w] = rasterize(
                    obj["geom"],
                    out_shape=(h, w),
                    transform=window_transform,
                    fill=0,
                    dtype="uint8",
                )

                multimask = torch.from_numpy(multimask)
                image = torch.from_numpy(image)
                # Same random quarter-turn for image and mask (k == 0: unchanged).
                quarter_turns = random.randint(0, 3)
                if quarter_turns:
                    multimask = torch.rot90(multimask, k=quarter_turns, dims=[-2, -1])
                    image = torch.rot90(image, k=quarter_turns, dims=[-2, -1])
                image_tensors.append(image)
                mask_tensors.append(multimask)
                if len(image_tensors) < batch_size:
                    continue

                batch_img = torch.stack(image_tensors, dim=0)
                batch_masks = torch.stack(mask_tensors, dim=0)
                image_tensors = []
                mask_tensors = []
                batch_img = F.interpolate(batch_img, scale_factor=0.25, mode='area')
                batch_masks = F.interpolate(batch_masks.float(), scale_factor=0.25, mode='nearest')
                batch_img = batch_img.float().to(device)
                batch_masks = batch_masks.to(device)

                optimizer.zero_grad()
                outputs = model(batch_img)
                loss = criterion(outputs, batch_masks)
                loss.backward()
                optimizer.step()
                if scheduler_limit is None or scheduler.last_epoch < scheduler_limit:
                    scheduler.step()

                loss_value = loss.item()
                train_loss.append(loss_value)
                # NOTE(review): the reported loss is the batch loss divided by
                # 4, as before; confirm whether the criterion already averages.
                running_loss = loss_value / 4
                # Metrics need no gradients; keep them off the autograd graph.
                with torch.no_grad():
                    pred_masks = torch.sigmoid(outputs.detach()).float()
                    iou, prec, rec, f1 = metrics(pred_masks, batch_masks, batch_img, 0.5, device)

                print(f'Epoch [{epochal_counter+1}], Loss: {running_loss:.4f},  mean IOU:[{iou}], Precision[{prec:.4f}], Recall [{rec:.4f}], F1_score [{f1:.4f}]')
                epochal_counter += 1
                plot_loss(train_loss)

                if iou > max_iou:
                    save_checkpoint('Segformer_b0_2048_1_li_dem_maxiou.pth', epochal_counter, loss_value, iou)
                    max_iou = iou
                save_checkpoint('Segformer_b0_2048_1_li_dem.pth', epochal_counter, loss_value, iou)
                torch.cuda.empty_cache()
    return model

    
# Alternative backbone, currently disabled:
#model = UNet(n_channels=12, n_classes=1)
# Run training on the notebook-level `model`. With train_from_scratch=False
# no checkpoint is loaded (train_model only loads one when the flag is True).
trained_model = train_model(model, 0, train_from_scratch=False)

# Обучение (без предварительного вычисления окон, но с аугментациями)

In [None]:
import torch.nn.functional as F
import random

# Side length, in pixels, of the square training tiles.
tile_size=2048
# Notebook-level optimizer; rebinds the `optimizer` created in the earlier
# training cell and is read as a global by train_model below.
optimizer = optim.AdamW(model.parameters(), lr=1e-4)
# The one-cycle scheduler is disabled for this training variant:
# scheduler = torch.optim.lr_scheduler.OneCycleLR(
#     optimizer,
#     max_lr=1e-4, 
#     total_steps=1000,
#     pct_start=0.05,
#     anneal_strategy='cos'
# )


def train_model(model, photo_counter, train_from_scratch=False):
    """Train the binary segmentation model on tiles cut on the fly from ``train_dataset``.

    For every sample that has a "курганы" annotation file, each raster is
    walked in ``tile_size`` steps with one random offset per raster. For each
    tile the first channel of ``multigeojson_to_multichannel_mask`` is used as
    the target mask. Tiles with fewer than 100 mask pixels are treated as
    empty and at most one of them is admitted per batch. Tiles are rotated by
    a random multiple of 90 degrees, collected in batches of four, downscaled
    by a factor of four, and each batch is trained on for 50 consecutive
    optimisation steps.

    The function uses the notebook-level ``optimizer`` and ``train_dataset``.

    NOTE(review): this definition shadows the ``train_model`` defined in an
    earlier cell.

    Args:
        model: network to train; moved to the selected device in place.
        photo_counter: start value of the processed-sample counter that is
            printed for every sample.
        train_from_scratch: despite the name, ``True`` RESUMES from the
            checkpoint ``Segformer_b0_2048_1.pth``; ``False`` trains the model
            as passed.

    Returns:
        The same ``model`` object after training.
    """
    batch_size = 4
    passes = 20
    sub_epochs = 50          # optimisation steps performed on every batch
    min_mask_pixels = 100    # below this a tile counts as "empty"
    max_empty_per_batch = 1

    random.seed(a=838383)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    criterion = FocalDiceLossBinary()

    epochal_counter = 0
    max_iou = 0
    if train_from_scratch:
        # A plain relative path (works on every OS); map_location lets a
        # checkpoint written on GPU be resumed on a CPU-only machine.
        checkpoint = torch.load("Segformer_b0_2048_1.pth", map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        optimizer_to(optimizer, device)
        epochal_counter = checkpoint['epoch']
    model.to(device)

    def save_checkpoint(path, epoch, loss_value, iou):
        """Write model/optimizer state and the last metrics to ``path``."""
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            # Plain float: the file must not carry a graph-attached tensor.
            'loss': loss_value,
            'iou': iou,
        }, path)

    image_tensors = []
    mask_tensors = []
    train_loss = []
    counter_zero = 0
    dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=1, shuffle=True)
    for _ in range(passes):
        model.train()
        for images in dataloader:
            print(f"Количество обработанных изображений {photo_counter}")
            photo_counter += 1

            # Case-insensitive, like the file-name matching in the
            # window-extraction cell.
            has_kurgans = any("курганы" in name[0].lower() for name in images["vector_mask"])
            if not has_kurgans:
                if len(images["batches"]) > 0:
                    print("нет_курганов,", images["batches"][0][0][0])
                continue

            for res in images["batches"]:
                update_metadata_only(res[0][0], res[0][0], int(images["UTM"][0]))
                with rasterio.open(res[0][0]) as src:
                    # One random shift of the tile grid per raster, so that
                    # different passes see differently aligned tiles.
                    offset_x = random.randint(0, tile_size // 2)
                    offset_y = random.randint(0, tile_size // 2)
                    for y in tqdm(range(0, src.height, tile_size)):
                        for x in range(0, src.width, tile_size):
                            x_start, y_start = x + offset_x, y + offset_y
                            if x_start >= src.width or y_start >= src.height:
                                # The shifted tile starts outside the raster.
                                continue
                            window_bounds = (x_start, y_start, x_start + tile_size, y_start + tile_size)

                            # Channel 0 of the multi-channel mask is the target.
                            multimask = multigeojson_to_multichannel_mask(images["vector_mask"], src, window_bounds)[0]
                            # Count mask pixels over the WHOLE tile (indexing
                            # [0] again would only look at the first pixel row).
                            if len(np.nonzero(multimask)[0]) < min_mask_pixels:
                                if counter_zero >= max_empty_per_batch:
                                    continue
                                counter_zero += 1
                            # The image window is read only for tiles that are kept.
                            image = merge_multiband_windowss(res, x_start, y_start, tile_size)

                            multimask = torch.from_numpy(multimask)
                            image = torch.from_numpy(image)
                            # Same random quarter-turn for image and mask.
                            quarter_turns = random.randint(0, 3)
                            if quarter_turns:
                                multimask = torch.rot90(multimask, k=quarter_turns, dims=[-2, -1])
                                image = torch.rot90(image, k=quarter_turns, dims=[-2, -1])
                            image_tensors.append(image)
                            mask_tensors.append(multimask)
                            if len(image_tensors) < batch_size:
                                continue

                            batch_img = torch.stack(image_tensors, dim=0)
                            # (N, H, W) -> (N, 1, H, W)
                            batch_masks = torch.stack(mask_tensors, dim=0).unsqueeze(1)
                            image_tensors = []
                            mask_tensors = []
                            counter_zero = 0
                            batch_img = F.interpolate(batch_img, scale_factor=0.25, mode='area')
                            batch_masks = F.interpolate(batch_masks.float(), scale_factor=0.25, mode='nearest')
                            batch_img = batch_img.float().to(device)
                            batch_masks = batch_masks.to(device)

                            for sub_ep in range(sub_epochs):
                                outputs = model(batch_img)
                                loss = criterion(outputs, batch_masks)
                                optimizer.zero_grad()
                                loss.backward()
                                optimizer.step()

                                batch_loss = loss.item()
                                train_loss.append(batch_loss)
                                # Metrics need no gradients.
                                with torch.no_grad():
                                    pred_masks = torch.sigmoid(outputs.detach()).float()
                                    iou, prec, rec, f1 = metrics(pred_masks, batch_masks, batch_img, 0.5, device)

                                if sub_ep % 10 == 0:
                                    print(f'Epoch [{epochal_counter+1}], Loss: {batch_loss:.4f},  mean IOU:[{iou}], Precision[{prec:.4f}], Recall [{rec:.4f}], F1_score [{f1:.4f}]')
                                    epochal_counter += 1
                                    plot_loss(train_loss)
                                    print(res[0][0])
                                if iou > max_iou:
                                    save_checkpoint('Segformer_b0_2048_11_maxiou.pth', epochal_counter, batch_loss, iou)
                                    max_iou = iou

                            # The "latest" checkpoint is written once per batch
                            # rather than after each of the sub-steps.
                            save_checkpoint('Segformer_b0_2048_11.pth', epochal_counter, batch_loss, iou)
                            torch.cuda.empty_cache()
    return model

# Alternative backbone, currently disabled:
#model = UNet(n_channels=12, n_classes=1)
# Run the on-the-fly training variant on the notebook-level `model`. With
# train_from_scratch=False no checkpoint is loaded.
trained_model = train_model(model, 0, train_from_scratch=False)

# Обучение Классификатора

In [14]:
from tools.utils import metrics_det

In [15]:
from tools.Mob_net_classificator import MobileNetV3Small
# Tile classifier with 11 output classes.
# NOTE(review): this rebinds the notebook-level name `model`, which until
# this cell referred to the SegFormer segmentation network; any later cell
# that expects the segmentation model under `model` must re-create it.
model = MobileNetV3Small(num_classes=11)

In [None]:
# Side length, in pixels, of the square tiles cut for the classifier.
tile_size=2048
import random
import torch.nn.functional as F
def train_det_model(model, photo_counter, train_from_scratch=False):
    """Train the tile classifier on tiles cut on the fly from ``train_dataset``.

    Every raster is walked in ``tile_size`` steps with one random offset per
    raster. The label of a tile is derived from the multi-channel mask
    returned by ``multigeojson_to_multichannel_mask``: the first of the ten
    channels with more than 100 non-zero pixels gives class ``channel + 1``;
    a tile without such a channel is class 0. Per batch at most 6 class-0
    tiles and at most 3 tiles labelled by channel 1 are admitted; surplus
    tiles are skipped. Tiles are rotated by a random multiple of 90 degrees,
    collected in batches of 16, downscaled by a factor of four and used for
    one optimisation step with cross-entropy loss. The latest checkpoint is
    written after every step, plus a separate one whenever F1 improves.

    The function uses the notebook-level ``train_dataset``.

    Args:
        model: classifier to train; moved to the selected device in place.
        photo_counter: start value of the processed-sample counter that is
            printed for every sample.
        train_from_scratch: despite the name, ``True`` RESUMES from the
            checkpoint ``Mob_net_class_.pth``; ``False`` trains the model as
            passed.

    Returns:
        The same ``model`` object after training.
    """
    batch_size = 16
    passes = 20
    num_mask_channels = 10
    min_mask_pixels = 100
    max_empty_per_batch = 6
    max_road_per_batch = 3
    road_channel = 1

    random.seed(a=838383)
    optimizer = optim.AdamW(model.parameters(), lr=1e-6)
    criterion = nn.CrossEntropyLoss()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    epochal_counter = 0
    max_f1 = 0
    if train_from_scratch:
        # A plain relative path (works on every OS); map_location lets a
        # checkpoint written on GPU be resumed on a CPU-only machine.
        checkpoint = torch.load("Mob_net_class_.pth", map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        optimizer_to(optimizer, device)
        epochal_counter = checkpoint['epoch']
    model.to(device)

    def save_checkpoint(path, epoch, loss_value, f1, prec, rec):
        """Write model/optimizer state and the last metrics to ``path``."""
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            # Plain float: the file must not carry a graph-attached tensor.
            'loss': loss_value,
            'f1': f1,
            'prec': prec,
            'rec': rec,
        }, path)

    train_loss = []
    image_tensors = []
    label_list = []
    counter_zero = 0
    counter_road = 0
    dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=1, shuffle=True)
    for _ in range(passes):
        model.train()
        for images in dataloader:
            print(f"Количество обработанных изображений {photo_counter}")
            photo_counter += 1
            for res in images["batches"]:
                update_metadata_only(res[0][0], res[0][0], int(images["UTM"][0]))
                with rasterio.open(res[0][0]) as src:
                    # One random shift of the tile grid per raster.
                    offset_x = random.randint(0, tile_size)
                    offset_y = random.randint(0, tile_size)
                    for y in tqdm(range(0, src.height, tile_size)):
                        for x in range(0, src.width, tile_size):
                            x_start, y_start = x + offset_x, y + offset_y
                            if x_start >= src.width or y_start >= src.height:
                                # The shifted tile starts outside the raster.
                                continue
                            window_bounds = (x_start, y_start, x_start + tile_size, y_start + tile_size)
                            multimask = multigeojson_to_multichannel_mask(images["vector_mask"], src, window_bounds)

                            # First channel that is populated enough decides the label.
                            populated = next(
                                (ch for ch in range(num_mask_channels)
                                 if len(np.nonzero(multimask[ch])[0]) > min_mask_pixels),
                                None,
                            )
                            if populated is None:
                                if counter_zero >= max_empty_per_batch:
                                    continue
                                counter_zero += 1
                                class_ = 0
                            elif populated == road_channel:
                                if counter_road >= max_road_per_batch:
                                    continue
                                counter_road += 1
                                class_ = road_channel + 1
                            else:
                                class_ = populated + 1

                            # The image window is read only for tiles that are kept.
                            image = torch.from_numpy(
                                merge_multiband_windowss(res, x_start, y_start, tile_size)
                            )
                            quarter_turns = random.randint(0, 3)
                            if quarter_turns:
                                image = torch.rot90(image, k=quarter_turns, dims=[-2, -1])
                            image_tensors.append(image)
                            label_list.append(class_)
                            if len(image_tensors) < batch_size:
                                continue

                            batch_img = torch.stack(image_tensors, dim=0)
                            batch_labels = torch.tensor(label_list)
                            image_tensors = []
                            label_list = []
                            counter_zero = 0
                            counter_road = 0
                            batch_img = F.interpolate(batch_img, scale_factor=0.25, mode='area')
                            batch_img = batch_img.float().to(device)
                            batch_labels = batch_labels.to(device)

                            outputs = model(batch_img)
                            loss = criterion(outputs, batch_labels)
                            optimizer.zero_grad()
                            loss.backward()
                            optimizer.step()

                            loss_value = loss.item()
                            train_loss.append(loss_value)
                            # NOTE(review): the reported loss is the batch loss
                            # divided by 16, as before; CrossEntropyLoss already
                            # averages over the batch by default.
                            running_loss = loss_value / 16
                            # Metrics need no gradients.
                            with torch.no_grad():
                                prec, rec, f1 = metrics_det(outputs.detach(), batch_labels, batch_img, device)

                            if epochal_counter != 0 and epochal_counter % 10 == 0:
                                plot_loss(train_loss)
                            print(f'Epoch [{epochal_counter+1}], Loss: {running_loss:.4f}, Precision[{prec:.4f}], Recall [{rec:.4f}], F1_score [{f1:.4f}]')
                            epochal_counter += 1

                            if f1 > max_f1:
                                save_checkpoint('Mob_net_class_maxf1.pth', epochal_counter, loss_value, f1, prec, rec)
                                max_f1 = f1
                            save_checkpoint('Mob_net_class_.pth', epochal_counter, loss_value, f1, prec, rec)
                            torch.cuda.empty_cache()
    return model

# Train the classifier bound to `model`. With train_from_scratch=False no
# checkpoint is loaded (train_det_model only loads one when the flag is True).
trained_model = train_det_model(model, 0, train_from_scratch=False)

# Тестирование

In [16]:
from tools.Mob_net_classificator import MobileNetV3Small
# Classifier instance with 11 output classes. test_model() loads the
# "Mob_net_class_.pth" weights into this module-level object and moves it to the
# inference device.
vis_model = MobileNetV3Small(num_classes=11)

In [None]:
import torch.nn.functional as F
import shapely
from shapely.geometry import MultiPolygon, Polygon, mapping, shape
from tools.test_module import mask_to_polygons, apply_transforms
from tools.utils import metrics_test
import json

# Side length, in pixels, of the square sliding window used by test_model().
tile_size = 2048
# Coordinate transforms applied to each predicted polygon:
#   1. window (tile) pixels -> full-image pixels
#   2. full-image pixels    -> map coordinates
#   3. map coordinates      -> EPSG:3857
# NOTE(review): the original note said "epsg:3587"; presumably a typo, since
# test_model() writes "urn:ogc:def:crs:EPSG::3857" as the feature CRS — confirm.

def _find_kurgan_markup(vector_mask_names):
    """Return the path of the vector layer whose name contains "курганы", or None.

    Every entry is indexed with ``[0]`` (presumably because the batch-size-1
    DataLoader wraps each path in a one-element sequence — TODO confirm).
    When several entries match, the last one wins.
    """
    markup_path = None
    for name in vector_mask_names:
        if "курганы" in name[0]:
            markup_path = name[0]
    return markup_path


def _split_geometry(geometry):
    """Return the parts of ``geometry`` as a list.

    A union of polygons is a multi-part geometry only when it has several
    disjoint parts; a single ``Polygon`` has no ``.geoms`` attribute, so it is
    wrapped in a list instead of being iterated.
    """
    if geometry.is_empty:
        return []
    if hasattr(geometry, "geoms"):
        return list(geometry.geoms)
    return [geometry]


def _make_feature(geometry, fid, shared_properties):
    """Build one GeoJSON feature dict from a shapely geometry.

    ``shared_properties`` is copied, and ``fid`` is appended as the last key.
    """
    properties = dict(shared_properties)
    properties["fid"] = fid
    return {
        "type": "Feature",
        "properties": properties,
        "geometry": mapping(geometry),
    }


def test_model(model, photo_counter):
    """Run tiled inference over ``test_dataset`` and score the predicted polygons.

    Loads "Segformer_b0_2048_1_kurg.pth" into ``model`` (and "Mob_net_class_.pth"
    into the module-level ``vis_model``), slides a ``tile_size`` window with 50 %
    overlap over every raster, thresholds the sigmoid output at 0.7 and converts
    it to polygons. After each image the accumulated predictions are written to
    "result1.geojson", the accumulated ground truth to "to_compare.geojson", and
    ./tools/compute_metrics_qual.py is run on both files.

    Only the "kurgany" class is evaluated; images without a "курганы" vector
    layer are skipped.

    Args:
        model: segmentation network; its weights are overwritten from the checkpoint.
        photo_counter: starting value of the processed-raster counter that is
            printed to stdout.

    Returns:
        None. Results are produced as files and printed output.
    """
    result = {"type": "FeatureCollection", "features": []}
    to_compare = {"type": "FeatureCollection", "features": []}

    checkpoint = torch.load("Segformer_b0_2048_1_kurg.pth", map_location='cpu')
    model.load_state_dict(checkpoint['model_state_dict'])
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)
    model.eval()

    # The classifier is loaded but not used for inference here: the tile
    # pre-filter that called vis_model is disabled. The load is kept because it
    # mutates the module-level vis_model.
    checkpoint2 = torch.load("Mob_net_class_.pth", map_location='cpu')
    vis_model.load_state_dict(checkpoint2['model_state_dict'])
    vis_model.to(device)

    # Evaluation order is kept deterministic so repeated runs write the same files.
    dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False)
    stride = tile_size // 2  # neighbouring windows overlap by half a tile
    threshhold = 0.7

    for images in dataloader:
        torch.cuda.empty_cache()
        markup_path = _find_kurgan_markup(images["vector_mask"])
        utm_code = images["UTM"][0]
        polygons = []
        raster_path = None

        for res in images["batches"]:
            raster_path = res[0][0]
            update_metadata_only(raster_path, raster_path, int(utm_code))
            with rasterio.open(raster_path) as src:
                if markup_path is None:
                    print("нет_курганы,", raster_path)
                    break
                photo_counter += 1
                print(f"Количество обработанных изображений {photo_counter}")

                # Inference only: do not build an autograd graph for every tile.
                with torch.no_grad():
                    for y in tqdm(range(0, src.height, stride)):
                        for x in range(0, src.width, stride):
                            window_bounds = (x, y, x + tile_size, y + tile_size)
                            image = merge_multiband_windowss(res, x, y, tile_size)
                            multimask = multigeojson_to_multichannel_mask(
                                images["vector_mask"], src, window_bounds
                            )[0]

                            # The network runs on tiles downscaled by a factor of 4.
                            image = torch.from_numpy(image).unsqueeze(0)
                            image = F.interpolate(image, scale_factor=0.25, mode='area')
                            multimask = torch.from_numpy(multimask).unsqueeze(0).unsqueeze(0)
                            multimask = F.interpolate(
                                multimask.float(), scale_factor=0.25, mode='nearest'
                            ).int()

                            image = image.float().to(device)
                            outputs = torch.sigmoid(model(image))
                            multimask = multimask.to(device)
                            # NOTE(review): the returned metrics are not used here; the
                            # call is kept in case metrics_test reports them itself.
                            metrics_test(outputs, multimask, image, threshhold, device)

                            # Back to the original tile resolution before vectorising.
                            outputs = F.interpolate(outputs, scale_factor=4, mode='nearest')
                            outputs = outputs.detach().cpu().numpy()
                            # May break because of the (h, w) layout rasterio uses for images.
                            for polygon in mask_to_polygons(outputs[0][0], x, y, threshhold):
                                polygons.append(
                                    apply_transforms(polygon, "EPSG:" + utm_code, src.transform)
                                )
                            torch.cuda.empty_cache()

        if markup_path is None or raster_path is None:
            continue

        # Overlapping tiles detect the same object several times: merge them.
        polygons = _split_geometry(shapely.union_all(polygons).normalize())
        print(f'Image [{raster_path}]')

        region_name = images["region_name"][0]
        sub_region_name = images["sub_region_name"][0]
        shared_properties = {
            "class_name": "kurgany",  # kurgany only
            "region_name": region_name,
            "sub_region_name": "" if region_name == sub_region_name else sub_region_name,
            "markup_type": images["markup_type"][0],
            "original_crs": "urn:ogc:def:crs:EPSG::" + utm_code,
            "crs": "urn:ogc:def:crs:EPSG::3857",
        }

        # NOTE(review): fid restarts at 0 for every image while both collections
        # accumulate across images, so fids repeat between images — confirm this
        # is what compute_metrics_qual.py expects.

        # ---------- convert the reference markup into the comparison format
        with open(markup_path, encoding="utf-8") as markup_file:
            markup = json.load(markup_file)
        # For this region the coordinates are nested one level deeper.
        nested_rings = "ЛИХУША" in region_name
        fid = 0
        for gt_feature in markup["features"]:
            for raw_ring in gt_feature["geometry"]["coordinates"]:
                ring = raw_ring[0] if nested_rings else raw_ring
                gt_polygon = Polygon(tuple(point) for point in ring)
                to_compare["features"].append(_make_feature(gt_polygon, fid, shared_properties))
                fid += 1

        # ----------- record the model output
        for fid, polygon in enumerate(polygons):
            result["features"].append(_make_feature(polygon, fid, shared_properties))

        # Same relative names as the ones passed to the metrics script below.
        with open("result1.geojson", "w", encoding="utf-8") as handle:
            json.dump(result, handle, ensure_ascii=False, indent=2)
        with open("to_compare.geojson", "w", encoding="utf-8") as handle:
            json.dump(to_compare, handle, ensure_ascii=False, indent=2)
        # Equivalent to the `%run` line magic, spelled as a call so that the
        # function body is plain Python.
        get_ipython().run_line_magic(
            "run",
            "./tools/compute_metrics_qual.py --predictions result1.geojson  --ground-truth to_compare.geojson",
        )

# Evaluate the model on the test set; the processed-image counter starts at 0.
test_model(model, 0)                                 