In [9]:
from ultralytics import YOLO
import torch.nn as nn
import copy
import torch
from ultralytics.nn.modules import Concat, C2f, Conv, SPPF

In [10]:
pretrained_model = YOLO('yolov8m.pt').model
backbone = nn.Sequential(*list(pretrained_model.model.children())[:10])

In [11]:
class CustomBackbone(nn.Module):
    def __init__(self, layers, out_idx=[4, 6, 8]):
        super().__init__()
        self.layers = nn.ModuleList(layers)
        self.out_idx = out_idx
        
    def forward(self, x):
        outputs = []
        for idx, layer in enumerate(self.layers):
            x = layer(x)
            if idx in self.out_idx:
                outputs.append(x)
        return outputs

In [12]:
backbone_rgb = CustomBackbone(backbone)
backbone_ir = copy.deepcopy(backbone_rgb)

# Modyfikacja pierwszej konwolucji dla IR (1 kanał)
backbone_ir.layers[0].conv = nn.Conv2d(1, 48, kernel_size=3, stride=2, padding=1, bias=False)

print(backbone_rgb.layers[0].conv.weight.shape)
print(backbone_ir.layers[0].conv.weight.shape)

torch.Size([48, 3, 3, 3])
torch.Size([48, 1, 3, 3])


In [13]:
class CustomNeck(nn.Module):
    def __init__(self, fused_channels):
        # super().__init__()
        # self.debugged = False
        # # fused_channels to liczba kanałów po konkatenacji, np. [192, 384, 576]
        # self.layer9 = SPPF(fused_channels[2], fused_channels[2] // 2)  # SPPF dla najgłębszej skali
        # self.layer10 = nn.Upsample(scale_factor=2, mode='nearest')
        # self.layer11 = Concat()
        # self.layer12 = C2f(fused_channels[2] // 2 + fused_channels[1], fused_channels[1] // 2, n=2)
        # self.layer13 = nn.Upsample(scale_factor=2, mode='nearest')
        # self.layer14 = Concat()
        # self.layer15 = C2f(fused_channels[1] // 2 + fused_channels[0], fused_channels[0] // 2, n=2)
        # self.layer16 = Conv(fused_channels[0] // 2, fused_channels[0] // 2, 3, s=2)
        # self.layer17 = Concat()
        # self.layer18 = C2f(fused_channels[0] // 2 + fused_channels[1], fused_channels[1] // 2, n=2)
        # self.layer19 = Conv(fused_channels[1] // 2, fused_channels[1] // 2, 3, s=2)
        # self.layer20 = Concat()
        # self.layer21 = C2f(fused_channels[1] // 2 + fused_channels[2], fused_channels[2] // 2, n=2)
        super().__init__()
        self.debugged = False
        self.layer9 = SPPF(fused_channels[2], fused_channels[2])  # 576 -> 576
        self.layer10 = nn.Upsample(scale_factor=2, mode='nearest')
        self.layer11 = Concat()
        self.layer12 = C2f(fused_channels[2] + fused_channels[1], fused_channels[1], n=2)  # 576 + 384 = 960 -> 384
        self.layer13 = nn.Upsample(scale_factor=2, mode='nearest')
        self.layer14 = Concat()
        self.layer15 = C2f(fused_channels[1] + fused_channels[0], fused_channels[0], n=2)  # 384 + 192 = 576 -> 192
        self.layer16 = Conv(fused_channels[0], fused_channels[0], 3, s=2)  # 192 -> 192
        self.layer17 = Concat()
        self.layer18 = C2f(fused_channels[0] + fused_channels[1], fused_channels[1], n=2)  # 192 + 384 = 576 -> 384
        self.layer19 = Conv(fused_channels[1], fused_channels[1], 3, s=2)  # 384 -> 384
        self.layer20 = Concat()
        self.layer21 = C2f(fused_channels[1] + fused_channels[2], fused_channels[2], n=2)  # 384 + 576 = 960 -> 576

    # def forward(self, fused):
    #     feat1, feat2, feat3 = fused  # feat1: warstwa 4, feat2: warstwa 6, feat3: warstwa 8

    #     x = self.layer9(feat3)
    #     x = self.layer10(x)
    #     x = self.layer11([x, feat2])
    #     x = self.layer12(x)
    #     x = self.layer13(x)
    #     x = self.layer14([x, feat1])
    #     feat_shallow = self.layer15(x)

    #     x = self.layer16(feat_shallow)
    #     x = self.layer17([x, feat2])
    #     feat_mid = self.layer18(x)
    #     x = self.layer19(feat_mid)
    #     x = self.layer20([x, feat3])
    #     feat_deep = self.layer21(x)
    #     print(f"feat_shallow channels: {feat_shallow.shape[1]}")
    #     print(f"feat_mid channels: {feat_mid.shape[1]}")
    #     print(f"feat_deep channels: {feat_deep.shape[1]}")
    #     return [feat_shallow, feat_mid, feat_deep]
    
    def forward(self, fused):
        feat1, feat2, feat3 = fused
        if not self.debugged:
            display(f"Input fused channels: {[f.shape[1] for f in fused]}")

        x = self.layer9(feat3)
        if not self.debugged:
            display(f"After layer9: {x.shape[1]} channels")

        x = self.layer10(x)
        x = self.layer11([x, feat2])
        if not self.debugged:
            display(f"After layer11: {x.shape[1]} channels")

        x = self.layer12(x)
        if not self.debugged:
            display(f"After layer12: {x.shape[1]} channels")

        x = self.layer13(x)
        x = self.layer14([x, feat1])
        feat_shallow = self.layer15(x)
        if not self.debugged:
            display(f"feat_shallow: {feat_shallow.shape[1]} channels")

        x = self.layer16(feat_shallow)
        x = self.layer17([x, feat2])
        feat_mid = self.layer18(x)
        if not self.debugged:
            display(f"feat_mid: {feat_mid.shape[1]} channels")

        x = self.layer19(feat_mid)
        x = self.layer20([x, feat3])
        feat_deep = self.layer21(x)
        if not self.debugged:
            display(f"feat_deep: {feat_deep.shape[1]} channels")
            self.debugged = True

        return [feat_shallow, feat_mid, feat_deep]

In [14]:
# Definicja CustomYOLO
class CustomYOLO(nn.Module):
    def __init__(self, pretrained_model, backbone_rgb, backbone_ir):
        super().__init__()
        self._init_attributes(pretrained_model)
        self.debugged = False
        self.model = nn.ModuleList([
            backbone_rgb,              
            backbone_ir,               
            CustomNeck(fused_channels=[192, 384, 576]),  
            copy.deepcopy(pretrained_model.model[-1])  
        ])
        # Warstwy redukujące kanały po fuzji
        self.reduce_channels = nn.ModuleList([
            nn.Conv2d(384, 192, kernel_size=1, stride=1, padding=0),  # Redukcja z 384 do 192
            nn.Conv2d(768, 384, kernel_size=1, stride=1, padding=0),  # Redukcja z 768 do 384
            nn.Conv2d(1152, 576, kernel_size=1, stride=1, padding=0)  # Redukcja z 1152 do 576
        ])
        
    def _init_attributes(self, pretrained_model):
        self.args = pretrained_model.args
        self.stride = pretrained_model.model[-1].stride
        self.anchors = pretrained_model.model[-1].anchors
        self.nc = pretrained_model.nc
        self.names = pretrained_model.names

    # def forward(self, x_rgb, x_ir):
    #     # Ekstrakcja cech
    #     features_rgb = self.model[0](x_rgb)  
    #     features_ir = self.model[1](x_ir)   
        
    #     fused = [torch.cat([f_rgb, f_ir], dim=1) for f_rgb, f_ir in zip(features_rgb, features_ir)]
        
    #     neck_outputs = self.model[2](fused)
    #     print(f"RGB features channels: {[f.shape[1] for f in features_rgb]}")
    #     print(f"IR features channels: {[f.shape[1] for f in features_ir]}")
    #     print(f"Fused features channels: {[f.shape[1] for f in fused]}")
    #     print(f"Neck outputs channels: {[n.shape[1] for n in neck_outputs]}")
    #     return self.model[3](neck_outputs)

    def forward(self, x_rgb, x_ir):
        if not self.debugged:
            display(f"Input RGB shape: {x_rgb.shape}")
            display(f"Input IR shape: {x_ir.shape}")

        features_rgb = self.model[0](x_rgb)
        features_ir = self.model[1](x_ir)

        if not self.debugged:
            display(f"RGB features channels: {[f.shape[1] for f in features_rgb]}")
            display(f"IR features channels: {[f.shape[1] for f in features_ir]}")

        fused = [torch.cat([f_rgb, f_ir], dim=1) for f_rgb, f_ir in zip(features_rgb, features_ir)]

        if not self.debugged:
            display(f"Fused features channels (before reduction): {[f.shape[1] for f in fused]}")

        fused = [self.reduce_channels[i](fused[i]) for i in range(len(fused))]

        if not self.debugged:
            display(f"Fused features channels (after reduction): {[f.shape[1] for f in fused]}")

        neck_outputs = self.model[2](fused)

        if not self.debugged:
            display(f"Neck outputs channels: {[n.shape[1] for n in neck_outputs]}")
            self.debugged = True

        return self.model[3](neck_outputs)

# Inicjalizacja modelu
custom_model = CustomYOLO(pretrained_model, backbone_rgb, backbone_ir)

In [15]:
from torch.utils.data import Dataset, DataLoader, default_collate
import cv2
import os
import torch

class MultimodalYOLODataset(Dataset):
    def __init__(self, rgb_dir, ir_dir, annotations_dir, img_size=(640, 640)):
        self.rgb_dir = rgb_dir
        self.ir_dir = ir_dir
        self.annotations_dir = annotations_dir
        self.img_files = sorted(os.listdir(rgb_dir))
        self.img_size = img_size

    def __len__(self):
        return len(self.img_files)

    def __getitem__(self, idx):
        img_name = self.img_files[idx]
        rgb_path = os.path.join(self.rgb_dir, img_name)
        ir_path = os.path.join(self.ir_dir, img_name)
        label_path = os.path.join(self.annotations_dir, img_name.replace('.jpg', '.txt'))

        # Wczytanie obrazów
        img_rgb = cv2.imread(rgb_path)
        if img_rgb is None:
            raise FileNotFoundError(f"Nie można wczytać obrazu RGB: {rgb_path}")
        img_rgb = cv2.cvtColor(img_rgb, cv2.COLOR_BGR2RGB)
        img_rgb = cv2.resize(img_rgb, self.img_size)  # Zmień rozmiar

        img_ir = cv2.imread(ir_path, cv2.IMREAD_GRAYSCALE)
        if img_ir is None:
            raise FileNotFoundError(f"Nie można wczytać obrazu IR: {ir_path}")
        img_ir = cv2.resize(img_ir, self.img_size)
        img_ir = img_ir[..., None]  # Dodanie wymiaru kanału: (H, W) -> (H, W, 1)

        # Konwersja na tensory i normalizacja
        img_rgb = torch.from_numpy(img_rgb).permute(2, 0, 1).float() / 255.0
        img_ir = torch.from_numpy(img_ir).permute(2, 0, 1).float() / 255.0

        # Wczytanie adnotacji
        if os.path.exists(label_path):
            with open(label_path, 'r') as f:
                labels = [list(map(float, line.split())) for line in f.readlines()]
            labels = torch.tensor(labels) if labels else torch.zeros((0, 5))
            if labels.numel() > 0:  # Dodaj kolumnę batch_idx (zostanie nadpisaną w collate_fn)
                labels = torch.cat([torch.zeros((labels.shape[0], 1)), labels], dim=1)
            else:
                labels = torch.zeros((0, 6))  # Pusty tensor z 6 kolumnami
        else:
            labels = torch.zeros((0, 6))

        return img_rgb, img_ir, labels

def custom_collate_fn(batch):
    img_rgb, img_ir, targets = zip(*batch)
    img_rgb = default_collate(img_rgb)
    img_ir = default_collate(img_ir)
    for i, target in enumerate(targets):
        if target.numel() > 0:  # Ustaw batch_idx
            target[:, 0] = i
    return img_rgb, img_ir, targets

# Utworzenie dataloadera
train_dataset = MultimodalYOLODataset(
    rgb_dir='LLVIP/visible/train',
    ir_dir='LLVIP/infrared/train',
    annotations_dir='LLVIP/Annotations',
    img_size=(640, 640)
)
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, collate_fn=custom_collate_fn)

In [16]:
# Ustawienie urządzenia
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = custom_model.to(device)

debug_dataset = torch.utils.data.Subset(train_dataset, range(16))  # Tylko 16 próbek
debug_loader = DataLoader(debug_dataset, batch_size=16, shuffle=False, collate_fn=custom_collate_fn)

# Debugowanie na małym podzbiorze
model.eval()  # Wyłącz tryb treningowy, aby uniknąć aktualizacji wag
with torch.no_grad():  # Wyłącz obliczanie gradientów
    for batch in debug_loader:
        img_rgb, img_ir, targets = batch
        img_rgb = img_rgb.to(device)
        img_ir = img_ir.to(device)
        targets = [t.to(device) for t in targets]
        outputs = model(img_rgb, img_ir)
        break  # Zatrzymaj po pierwszej iteracji

'Input RGB shape: torch.Size([16, 3, 640, 640])'

'Input IR shape: torch.Size([16, 1, 640, 640])'

'RGB features channels: [192, 384, 576]'

'IR features channels: [192, 384, 576]'

'Fused features channels (before reduction): [384, 768, 1152]'

'Fused features channels (after reduction): [192, 384, 576]'

'Input fused channels: [192, 384, 576]'

'After layer9: 576 channels'

'After layer11: 960 channels'

'After layer12: 384 channels'

'feat_shallow: 192 channels'

'feat_mid: 384 channels'

'feat_deep: 576 channels'

'Neck outputs channels: [192, 384, 576]'

In [None]:
import torch
from ultralytics.utils.loss import v8DetectionLoss
from tqdm.notebook import tqdm  # Zmień na tqdm.notebook

# Ustawienie urządzenia
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Przygotowanie modelu
model = custom_model.to(device)
model.train()
loss_fn = v8DetectionLoss(model)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Pętla treningowa z lepszym outputem
num_epochs = 50
best_loss = float('inf')

for epoch in range(num_epochs):
    model.train()
    epoch_loss = 0.0  # Licznik straty w epoce
    num_batches = 0   # Licznik batchy

    # Użyj tqdm.notebook do paska postępu
    progress_bar = tqdm(train_loader, desc=f"Epoka {epoch+1}/{num_epochs}", leave=True)
    
    for batch in progress_bar:
        img_rgb, img_ir, targets = batch
        img_rgb = img_rgb.to(device)
        img_ir = img_ir.to(device)
        targets = [t.to(device) for t in targets]
        
        optimizer.zero_grad()
        outputs = model(img_rgb, img_ir)
        loss = loss_fn(outputs, targets)
        loss.backward()
        optimizer.step()
        
        # Oblicz średnią stratę w epoce
        epoch_loss += loss.item()
        num_batches += 1
        
        # Aktualizuj pasek postępu z bieżącą stratą
        progress_bar.set_postfix({'Strata': f"{loss.item():.4f}"})
    
    # Średnia strata w epoce
    avg_epoch_loss = epoch_loss / num_batches
    print(f"Epoka {epoch+1}/{num_epochs}, Średnia strata: {avg_epoch_loss:.4f}", flush=True)
    
    # Zapisz model, jeśli strata jest lepsza
    if avg_epoch_loss < best_loss:
        best_loss = avg_epoch_loss
        torch.save(model.state_dict(), 'best_model.pth')

Epoka 1/50:   0%|          | 0/752 [00:00<?, ?it/s]

TypeError: list indices must be integers or slices, not str

: 