In [None]:
import pandas as pd
from torch.utils.data import Dataset
import PIL
import torch.nn as nn
import os
from transformers import AutoImageProcessor, Swinv2ForImageClassification
import torch
import matplotlib.pyplot as plt
import numpy as np





In [None]:
model_name = "microsoft/swinv2-base-patch4-window8-256"
device = "cuda" if  torch.cuda.is_available else "cpu"
filename_col = 'image_name'
target_col = 'target' # для одинаковых датасетов с обычным обучением

In [3]:
train = pd.read_csv('train.csv')
test = pd.read_csv('test.csv')

In [4]:
preprocessor = AutoImageProcessor.from_pretrained(model_name)

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


In [None]:
class ISICDataset_train(Dataset):
    def __init__(self, image_dir, df, preprocessor, transform=None):
        self.image_dir = image_dir
        self.preprocessor = preprocessor
        self.df = df.reset_index(drop=True)
        self.transform = transform

    def __getitem__(self, idx):
        image_name = self.df.iloc[idx][filename_col] + '.jpg'
        image_path = os.path.join(self.image_dir, image_name)

        with PIL.Image.open(image_path) as img:
            image = img.convert("RGB")

        if self.transform:
            augmented = self.transform(image=np.array(image))
            image = augmented['image']
            if isinstance(image, torch.Tensor):
                image = image.permute(1, 2, 0).cpu().numpy()

        inputs = self.preprocessor(image, return_tensors="pt")
        pixel_values = inputs["pixel_values"].squeeze(0)

        label = int(self.df.iloc[idx][target_col])
        return {"pixel_values": pixel_values, "labels": label}

    def __len__(self):
        return len(self.df)
    
    def show_image(self, idx):
        image_name = self.df.iloc[idx][filename_col] + '.jpg'
        image_path = os.path.join(self.image_dir, image_name)

        with PIL.Image.open(image_path) as img:
            image = img.convert("RGB")

        if self.transform:
            augmented = self.transform(image=np.array(image))
            image = augmented["image"]
        else:
            image = np.array(image)

        if isinstance(image, torch.Tensor):
            image = image.permute(1, 2, 0).cpu().numpy()

        plt.imshow(image)
        plt.axis('off')
        plt.title(f"Index: {idx}, Label: {self.df.iloc[idx][target_col]}")
        plt.show()



In [6]:
base_model = Swinv2ForImageClassification.from_pretrained(
    "swin2-base/checkpoint-1230",
    num_labels=2,
    use_safetensors=True,
    ignore_mismatched_sizes=True
)

In [None]:
# Другой вариант загрузки из чекпоинта

# from safetensors.torch import load_file

# state_dict = load_file("swin2-base-3/checkpoint-336/model.safetensors")
# model.load_state_dict(state_dict)

In [None]:
class SwinWrapper(nn.Module):
    def __init__(self, backbone_model: nn.Module, num_labels=2):
        super().__init__()
        # Основная фича-сеть без классификационной головы
        self.model = backbone_model
        # Название backbone модуля можно получить print(model)
        self.backbone = self.model.swinv2
        self.pool = nn.AdaptiveAvgPool1d(1)

    @torch.no_grad()
    def get_image_features(self, pixel_values):
        # pixel_values: (B, 3, H, W)
        outputs = self.backbone(pixel_values)
        # outputs.last_hidden_state: (B, N, C)
        # Берем cls-токен
        cls_emb = outputs.last_hidden_state[:, 0]  # (B, C)
        return cls_emb

    def forward(self, pixel_values):
        return self.model(pixel_values)

In [8]:
model = SwinWrapper(base_model)

In [None]:
from torch.utils.data import DataLoader
from tqdm import tqdm
import torch

# Датасеты без аугментаций
train_dataset = ISICDataset_train("train", train, preprocessor, transform=None)

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=False, num_workers=0)

model.eval()
model.to(device)

def extract_embeddings(loader, df):
    all_embs = []
    all_names = []

    for batch_idx, batch in tqdm(enumerate(loader), total=len(loader)):
        pixel_values = batch["pixel_values"].to(device)
        with torch.no_grad():
            emb = model.get_image_features(pixel_values)   # (B, C)

        emb = emb.cpu().numpy()
        all_embs.append(emb)

        # берем image_name из df по индексам батча
        start = batch_idx * loader.batch_size
        end = start + emb.shape[0]
        names = df.iloc[start:end][filename_col].tolist()
        all_names.extend(names)

    all_embs = np.vstack(all_embs)
    emb_df = pd.DataFrame(all_embs)
    emb_df.insert(0, filename_col, all_names)
    return emb_df

train_emb_df = extract_embeddings(train_loader, train)

100%|██████████| 518/518 [1:17:32<00:00,  8.98s/it]
  0%|          | 0/172 [00:00<?, ?it/s]


KeyError: 'target'

In [10]:
train_emb_df.to_csv("train_embeddings_swinv2.csv", index=False)

In [None]:
test[target_col] = 0
test_dataset  = ISICDataset_train("test",  test,  preprocessor, transform=None)
test_loader  = DataLoader(test_dataset,  batch_size=64, shuffle=False, num_workers=0)
test_emb_df  = extract_embeddings(test_loader, test)
test_emb_df.to_csv("test_embeddings_swinv2.csv", index=False)

100%|██████████| 172/172 [22:47<00:00,  7.95s/it]
