In [1]:
import os
import torch
import albumentations as A
import numpy as np
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from transformers import BeitForImageClassification
from torch.optim import AdamW
from albumentations.pytorch import ToTensorV2
from sklearn.metrics import accuracy_score

from transformers import BeitImageProcessor
import requests
import tqdm

In [2]:
# Directories of the three dataset splits, relative to the working directory.
train_path, val_path, test_path = (
    "split_dataset_masha/train",
    "split_dataset_masha/val",
    "split_dataset_masha/test",
)

In [3]:
# Configuration
# Use the GPU when torch can see one, otherwise fall back to the CPU.
run_device = "cuda" if torch.cuda.is_available() else "cpu"

config = {
    # Hugging Face checkpoint that is fine-tuned below.
    "model_name": "microsoft/beit-large-patch16-224-pt22k-ft22k",
    # Folder with train/val/test subfolders (placeholder value).
    "data_path": "/path/to/your/data",
    # Side length, in pixels, that every image is resized to.
    "img_size": 224,
    "batch_size": 64,
    "epochs": 5,
    "lr": 3e-5,
    # Worker processes per DataLoader.
    "num_workers": 4,
    "device": run_device,
}

In [4]:
# Augmentations with Albumentations
def _make_pipeline(augmentations=()):
    """Build resize -> (optional augmentations) -> normalize -> tensor.

    Every call creates fresh Resize/Normalize/ToTensorV2 instances, so the
    returned pipelines do not share transform objects with each other.
    """
    side = config['img_size']
    return A.Compose([
        A.Resize(side, side),
        *augmentations,
        A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),  # BEiT-specific normalization
        ToTensorV2()
    ])


# Training adds random flips, a small rotation and brightness/contrast jitter.
train_transform = _make_pipeline([
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.2),
    A.Rotate(limit=20, p=0.3),
    A.RandomBrightnessContrast(p=0.3),
])

# Validation and test only resize and normalize (no random augmentation).
val_transform = _make_pipeline()
test_transform = _make_pipeline()

In [5]:
class CustomDataset(Dataset):
    """Image-folder dataset: each sub-directory of ``root_dir`` is one class.

    Class names are the sorted names of the non-hidden sub-directories of
    ``root_dir``; the label of a sample is the index of its class in that
    sorted list. Files inside each class directory are enumerated in sorted
    order so the sample order is deterministic across runs and machines.

    Args:
        root_dir: Directory containing one sub-directory per class.
        transform: Optional callable invoked as ``transform(image=array)``
            that returns a mapping with an ``"image"`` entry
            (the Albumentations calling convention).

    Attributes:
        classes: Sorted list of class names.
        class_to_idx: Mapping class name -> integer label.
        get_class: Mapping integer label -> class name.
        samples: List of ``(image_path, label)`` tuples.
    """

    def __init__(self, root_dir, transform=None):
        # Only real, non-hidden directories are classes. Stray files or
        # tool-generated folders such as ".ipynb_checkpoints" would otherwise
        # shift the label indices or make os.listdir() fail on a file.
        self.classes = sorted(
            entry for entry in os.listdir(root_dir)
            if not entry.startswith(".")
            and os.path.isdir(os.path.join(root_dir, entry))
        )
        self.class_to_idx = {cls: i for i, cls in enumerate(self.classes)}
        self.get_class = {i: cls for i, cls in enumerate(self.classes)}
        self.samples = []
        self.transform = transform

        for cls in self.classes:
            cls_dir = os.path.join(root_dir, cls)
            # os.listdir() order is arbitrary; sort for reproducibility.
            for img_name in sorted(os.listdir(cls_dir)):
                img_path = os.path.join(cls_dir, img_name)
                # Skip hidden files (e.g. ".DS_Store") and nested directories.
                if img_name.startswith(".") or not os.path.isfile(img_path):
                    continue
                self.samples.append((img_path, self.class_to_idx[cls]))

    def __len__(self):
        """Return the number of collected samples."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, label)``.

        The image is decoded as RGB into a numpy array and, when a transform
        is set, replaced by the transform's ``"image"`` output. The label is
        returned as a ``torch.long`` scalar tensor.
        """
        img_path, label = self.samples[idx]
        # The context manager closes the file handle as soon as the pixels
        # have been copied into the array, instead of waiting for GC.
        with Image.open(img_path) as img:
            image = np.array(img.convert("RGB"))

        if self.transform:
            augmented = self.transform(image=image)
            image = augmented["image"]

        return image, torch.tensor(label, dtype=torch.long)

In [6]:
# Data initialisation: one dataset per split, each paired with its transform.
train_dataset, val_dataset, test_dataset = (
    CustomDataset(split_dir, transform=split_transform)
    for split_dir, split_transform in (
        (train_path, train_transform),
        (val_path, val_transform),
        (test_path, test_transform),
    )
)

def _make_loader(dataset, shuffle):
    """Wrap ``dataset`` in a DataLoader using the batch size and worker count from ``config``."""
    return DataLoader(
        dataset,
        batch_size=config['batch_size'],
        shuffle=shuffle,
        num_workers=config['num_workers'],
    )


# Only the training split is shuffled; val/test keep the dataset order.
train_loader = _make_loader(train_dataset, shuffle=True)
val_loader = _make_loader(val_dataset, shuffle=False)
test_loader = _make_loader(test_dataset, shuffle=False)

In [7]:
# Standalone cross-entropy criterion.
# NOTE(review): no visible cell references `loss_fn`; train_model reads the loss
# from `outputs.loss` returned by the model instead. Confirm whether this is needed.
loss_fn = torch.nn.CrossEntropyLoss()

In [8]:
# One output logit per class folder found in the training split.
num_classes = len(train_dataset.classes)

# ignore_mismatched_sizes=True lets the checkpoint load even though its
# classifier head has a different number of outputs; the mismatched
# classifier weights are then newly initialised (see the warning below).
model = BeitForImageClassification.from_pretrained(
    config['model_name'],
    num_labels=num_classes,
    ignore_mismatched_sizes=True,
)
model = model.to(config["device"])

Some weights of BeitForImageClassification were not initialized from the model checkpoint at microsoft/beit-large-patch16-224-pt22k-ft22k and are newly initialized because the shapes did not match:
- classifier.weight: found shape torch.Size([21841, 1024]) in the checkpoint and torch.Size([144, 1024]) in the model instantiated
- classifier.bias: found shape torch.Size([21841]) in the checkpoint and torch.Size([144]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [9]:
# AdamW over every model parameter, using the learning rate from `config`.
optimizer = AdamW(model.parameters(), lr=config['lr'])

In [None]:
def train_model(model, train_loader, val_loader, optimizer, config):
    """Train ``model`` and evaluate it on ``val_loader`` after every epoch.

    The model is called as ``model(pixel_values=images, labels=labels)`` and
    must return an object exposing ``.loss`` and ``.logits``.

    Args:
        model: Network to optimise; expected to already live on
            ``config['device']``.
        train_loader: Sized iterable of ``(images, labels)`` training batches.
        val_loader: Sized iterable of ``(images, labels)`` validation batches.
        optimizer: Optimizer bound to ``model``'s parameters.
        config: Mapping that provides ``'device'`` and ``'epochs'``.

    Returns:
        A list with one dict per epoch holding ``"epoch"`` (1-based),
        ``"train_loss"``, ``"val_loss"`` (both averaged over batches) and
        ``"val_acc"`` (fraction of correctly classified validation samples).

    Raises:
        ValueError: If either loader contains no batches (the per-batch
            averages would otherwise divide by zero after a wasted epoch).
    """
    if len(train_loader) == 0 or len(val_loader) == 0:
        raise ValueError(
            "train_loader and val_loader must each contain at least one batch"
        )

    device = config['device']
    history = []

    for epoch in range(config['epochs']):
        # Training
        model.train()
        train_loss = 0.0
        for images, labels in train_loader:
            images = images.to(device)
            labels = labels.to(device)

            # Clear gradients *before* the backward pass so that gradients
            # left over from earlier calls cannot leak into the first step.
            optimizer.zero_grad()

            outputs = model(pixel_values=images, labels=labels)
            loss = outputs.loss

            loss.backward()
            optimizer.step()

            train_loss += loss.item()

        # Validation
        model.eval()
        val_loss = 0.0
        correct = 0
        seen = 0
        with torch.no_grad():
            for images, labels in val_loader:
                # Use the configured device (not a hard-coded .cuda()) so the
                # CPU fallback selected in the config actually works.
                images = images.to(device)
                labels = labels.to(device)

                outputs = model(pixel_values=images, labels=labels)
                val_loss += outputs.loss.item()

                preds = torch.argmax(outputs.logits, dim=1)
                correct += (preds == labels).sum().item()
                seen += labels.numel()

        # Metrics
        train_loss /= len(train_loader)
        val_loss /= len(val_loader)
        # Plain accuracy: share of validation samples predicted correctly.
        val_acc = correct / seen if seen else 0.0

        print(f"Epoch {epoch+1}/{config['epochs']}")
        print(f"Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.4f}")

        history.append({
            "epoch": epoch + 1,
            "train_loss": train_loss,
            "val_loss": val_loss,
            "val_acc": val_acc,
        })

    return history

In [24]:
# Uncomment to save the fine-tuned weights to disk:
# torch.save(model.state_dict(), "fine_tuned_beit.pth")

In [6]:
# Sanity check of the *original* ImageNet-22k checkpoint on a sample image.
# WARNING: this cell rebinds `model` to the un-fine-tuned checkpoint (on CPU),
# replacing any fine-tuned model created earlier in the notebook.
url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
demo_checkpoint = 'microsoft/beit-large-patch16-224-pt22k-ft22k'

# Bound the request time and fail loudly on HTTP errors instead of handing an
# error page to PIL; convert to RGB so grayscale/palette images also work.
response = requests.get(url, stream=True, timeout=30)
response.raise_for_status()
image = Image.open(response.raw).convert("RGB")

processor = BeitImageProcessor.from_pretrained(demo_checkpoint)
model = BeitForImageClassification.from_pretrained(demo_checkpoint)

inputs = processor(images=image, return_tensors="pt")

# Inference only: skip building the autograd graph.
with torch.no_grad():
    outputs = model(**inputs)
logits = outputs.logits

# model predicts one of the 21,841 ImageNet-22k classes
predicted_class_idx = logits.argmax(-1).item()
print("Predicted class:", model.config.id2label[predicted_class_idx])

preprocessor_config.json:   0%|          | 0.00/276 [00:00<?, ?B/s]

  return func(*args, **kwargs)


config.json:   0%|          | 0.00/1.67M [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/1.31G [00:00<?, ?B/s]

Predicted class: kitten, kitty
