In [1]:
import os
import random
from glob import glob
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torchvision.transforms as T
import torchvision.models as models
from torch.utils.data import Dataset, DataLoader

from tqdm import tqdm
from sklearn.metrics.pairwise import cosine_similarity

try:
    import faiss
except ImportError:
    faiss = None

SEED = 42
IMG_SIZE = 224
BATCH_SIZE = 64
EPOCHS = 100
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

random.seed(SEED)
torch.manual_seed(SEED)

class SimCLRDataset(Dataset):
    def __init__(self, image_paths, transform):
        self.image_paths = image_paths
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img = Image.open(self.image_paths[idx]).convert('RGB')
        xi = self.transform(img)
        xj = self.transform(img)
        return xi, xj

simclr_transform = T.Compose([
    T.RandomResizedCrop(IMG_SIZE, scale=(0.2, 1.0)),
    T.RandomHorizontalFlip(p=0.5),
    T.ColorJitter(0.8, 0.8, 0.8, 0.2),
    T.RandomGrayscale(p=0.2),
    T.GaussianBlur(kernel_size=9),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

class ProjectionHead(nn.Module):
    def __init__(self, in_dim=2048, hidden_dim=512, out_dim=128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, out_dim)
        )
    def forward(self, x):
        return nn.functional.normalize(self.net(x), dim=1)

class SimCLR(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
        self.backbone.fc = nn.Identity()
        # Заморозка первых 4 блоков 
        for param in list(self.backbone.parameters())[:6]:
            param.requires_grad = False
        self.projector = ProjectionHead(2048)
    def forward(self, x):
        h = self.backbone(x)
        z = self.projector(h)
        return z

def nt_xent_loss(z_i, z_j, temperature=0.1):
    z = torch.cat([z_i, z_j], dim=0)
    z = nn.functional.normalize(z, dim=1)
    sim_matrix = torch.matmul(z, z.T) / temperature
    batch_size = z_i.size(0)
    mask = torch.eye(2 * batch_size, dtype=torch.bool).to(DEVICE)
    sim_matrix = sim_matrix.masked_fill(mask, -float('inf'))
    positives = torch.cat([torch.arange(batch_size, 2 * batch_size), torch.arange(0, batch_size)]).to(DEVICE)
    sim_probs = nn.functional.log_softmax(sim_matrix, dim=1)
    loss = -sim_probs[torch.arange(2 * batch_size), positives]
    return loss.mean()

def train_simclr(model, dataloader, optimizer, scheduler=None):
    model.train()
    for epoch in range(EPOCHS):
        total_loss = 0
        for xi, xj in tqdm(dataloader, desc=f"Epoch {epoch+1}/{EPOCHS}"):
            xi, xj = xi.to(DEVICE), xj.to(DEVICE)
            zi, zj = model(xi), model(xj)
            loss = nt_xent_loss(zi, zj)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        if scheduler is not None:
            scheduler.step()
        print(f"Epoch {epoch+1}/{EPOCHS}, Loss: {total_loss/len(dataloader):.4f}")

all_images = glob('/kaggle/input/cow-f-crop/cow_f_crop/*.jpg')
print(f"картинок: {len(all_images)}")

dataset = SimCLRDataset(all_images, simclr_transform)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4, drop_last=True)

model = SimCLR().to(DEVICE)
LR = 0.3 * BATCH_SIZE / 256
optimizer = torch.optim.SGD(model.parameters(), lr=LR, momentum=0.9, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)

train_simclr(model, dataloader, optimizer, scheduler)

def compute_embeddings(model, paths):
    model.eval()
    embeddings = {}
    transform = T.Compose([
        T.Resize(256),
        T.CenterCrop(IMG_SIZE),
        T.ToTensor(),
        T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    with torch.no_grad():
        for path in tqdm(paths, desc="Embedding"):
            img = Image.open(path).convert('RGB')
            img_tensor = transform(img).unsqueeze(0).to(DEVICE)
            emb = model.backbone(img_tensor).squeeze().cpu().numpy()
            embeddings[os.path.basename(path)] = emb / np.linalg.norm(emb)
    return embeddings

embeddings = compute_embeddings(model, all_images)
np.save("cow_embeddings.npy", embeddings)
print("Эмбеддинги сохранены")

def find_similar(query_img_path, embeddings, top_k=5):
    query_name = os.path.basename(query_img_path)
    query_vector = embeddings[query_name].reshape(1, -1)
    names = list(embeddings.keys())
    all_vectors = np.stack([embeddings[name] for name in names])

    if faiss is not None:
        index = faiss.IndexFlatIP(all_vectors.shape[1])
        index.add(all_vectors.astype('float32'))
        D, I = index.search(query_vector.astype('float32'), top_k+1)
        print(f"\nТоп {top_k} похожих с: {query_name}\n")
        for i, idx in enumerate(I[0][1:top_k+1]):
            print(f"{names[idx]} (sim: {D[0][i+1]:.4f})")
    else:
        sims = cosine_similarity(query_vector, all_vectors)[0]
        top_idxs = sims.argsort()[::-1][1:top_k+1]
        print(f"\nТоп {top_k} похожих с: {query_name}\n")
        for idx in top_idxs:
            print(f"{names[idx]} (sim: {sims[idx]:.4f})")

картинок: 315


Downloading: "https://download.pytorch.org/models/resnet50-11ad3fa6.pth" to /root/.cache/torch/hub/checkpoints/resnet50-11ad3fa6.pth
100%|██████████| 97.8M/97.8M [00:00<00:00, 201MB/s]
Epoch 1/100: 100%|██████████| 4/4 [00:07<00:00,  1.89s/it]


Epoch 1/100, Loss: 4.8945


Epoch 2/100: 100%|██████████| 4/4 [00:06<00:00,  1.55s/it]


Epoch 2/100, Loss: 4.6826


Epoch 3/100: 100%|██████████| 4/4 [00:06<00:00,  1.68s/it]


Epoch 3/100, Loss: 4.7123


Epoch 4/100: 100%|██████████| 4/4 [00:06<00:00,  1.55s/it]


Epoch 4/100, Loss: 4.6309


Epoch 5/100: 100%|██████████| 4/4 [00:05<00:00,  1.48s/it]


Epoch 5/100, Loss: 4.5637


Epoch 6/100: 100%|██████████| 4/4 [00:06<00:00,  1.57s/it]


Epoch 6/100, Loss: 4.6192


Epoch 7/100: 100%|██████████| 4/4 [00:05<00:00,  1.49s/it]


Epoch 7/100, Loss: 4.5205


Epoch 8/100: 100%|██████████| 4/4 [00:06<00:00,  1.65s/it]


Epoch 8/100, Loss: 4.4749


Epoch 9/100: 100%|██████████| 4/4 [00:06<00:00,  1.57s/it]


Epoch 9/100, Loss: 4.4635


Epoch 10/100: 100%|██████████| 4/4 [00:06<00:00,  1.56s/it]


Epoch 10/100, Loss: 4.3365


Epoch 11/100: 100%|██████████| 4/4 [00:06<00:00,  1.60s/it]


Epoch 11/100, Loss: 4.2192


Epoch 12/100: 100%|██████████| 4/4 [00:06<00:00,  1.57s/it]


Epoch 12/100, Loss: 4.3065


Epoch 13/100: 100%|██████████| 4/4 [00:06<00:00,  1.65s/it]


Epoch 13/100, Loss: 4.1911


Epoch 14/100: 100%|██████████| 4/4 [00:06<00:00,  1.57s/it]


Epoch 14/100, Loss: 4.1826


Epoch 15/100: 100%|██████████| 4/4 [00:06<00:00,  1.53s/it]


Epoch 15/100, Loss: 4.2671


Epoch 16/100: 100%|██████████| 4/4 [00:06<00:00,  1.53s/it]


Epoch 16/100, Loss: 4.1183


Epoch 17/100: 100%|██████████| 4/4 [00:06<00:00,  1.55s/it]


Epoch 17/100, Loss: 4.2473


Epoch 18/100: 100%|██████████| 4/4 [00:05<00:00,  1.50s/it]


Epoch 18/100, Loss: 3.9782


Epoch 19/100: 100%|██████████| 4/4 [00:06<00:00,  1.55s/it]


Epoch 19/100, Loss: 3.9174


Epoch 20/100: 100%|██████████| 4/4 [00:06<00:00,  1.55s/it]


Epoch 20/100, Loss: 3.8898


Epoch 21/100: 100%|██████████| 4/4 [00:06<00:00,  1.51s/it]


Epoch 21/100, Loss: 3.8865


Epoch 22/100: 100%|██████████| 4/4 [00:06<00:00,  1.56s/it]


Epoch 22/100, Loss: 3.7227


Epoch 23/100: 100%|██████████| 4/4 [00:06<00:00,  1.55s/it]


Epoch 23/100, Loss: 3.7072


Epoch 24/100: 100%|██████████| 4/4 [00:06<00:00,  1.62s/it]


Epoch 24/100, Loss: 3.6039


Epoch 25/100: 100%|██████████| 4/4 [00:06<00:00,  1.60s/it]


Epoch 25/100, Loss: 3.5888


Epoch 26/100: 100%|██████████| 4/4 [00:06<00:00,  1.53s/it]


Epoch 26/100, Loss: 3.6582


Epoch 27/100: 100%|██████████| 4/4 [00:06<00:00,  1.54s/it]


Epoch 27/100, Loss: 3.7671


Epoch 28/100: 100%|██████████| 4/4 [00:06<00:00,  1.55s/it]


Epoch 28/100, Loss: 3.4170


Epoch 29/100: 100%|██████████| 4/4 [00:06<00:00,  1.61s/it]


Epoch 29/100, Loss: 3.3508


Epoch 30/100: 100%|██████████| 4/4 [00:06<00:00,  1.60s/it]


Epoch 30/100, Loss: 3.3970


Epoch 31/100: 100%|██████████| 4/4 [00:06<00:00,  1.53s/it]


Epoch 31/100, Loss: 3.1099


Epoch 32/100: 100%|██████████| 4/4 [00:06<00:00,  1.57s/it]


Epoch 32/100, Loss: 3.1995


Epoch 33/100: 100%|██████████| 4/4 [00:06<00:00,  1.56s/it]


Epoch 33/100, Loss: 3.0011


Epoch 34/100: 100%|██████████| 4/4 [00:06<00:00,  1.65s/it]


Epoch 34/100, Loss: 3.1155


Epoch 35/100: 100%|██████████| 4/4 [00:06<00:00,  1.50s/it]


Epoch 35/100, Loss: 3.2455


Epoch 36/100: 100%|██████████| 4/4 [00:06<00:00,  1.51s/it]


Epoch 36/100, Loss: 2.9124


Epoch 37/100: 100%|██████████| 4/4 [00:06<00:00,  1.52s/it]


Epoch 37/100, Loss: 2.8014


Epoch 38/100: 100%|██████████| 4/4 [00:06<00:00,  1.58s/it]


Epoch 38/100, Loss: 2.7955


Epoch 39/100: 100%|██████████| 4/4 [00:06<00:00,  1.66s/it]


Epoch 39/100, Loss: 2.6079


Epoch 40/100: 100%|██████████| 4/4 [00:06<00:00,  1.53s/it]


Epoch 40/100, Loss: 2.6201


Epoch 41/100: 100%|██████████| 4/4 [00:06<00:00,  1.52s/it]


Epoch 41/100, Loss: 2.6191


Epoch 42/100: 100%|██████████| 4/4 [00:05<00:00,  1.50s/it]


Epoch 42/100, Loss: 2.7030


Epoch 43/100: 100%|██████████| 4/4 [00:05<00:00,  1.49s/it]


Epoch 43/100, Loss: 2.4109


Epoch 44/100: 100%|██████████| 4/4 [00:06<00:00,  1.61s/it]


Epoch 44/100, Loss: 2.4031


Epoch 45/100: 100%|██████████| 4/4 [00:06<00:00,  1.58s/it]


Epoch 45/100, Loss: 2.3545


Epoch 46/100: 100%|██████████| 4/4 [00:06<00:00,  1.56s/it]


Epoch 46/100, Loss: 2.2200


Epoch 47/100: 100%|██████████| 4/4 [00:06<00:00,  1.51s/it]


Epoch 47/100, Loss: 2.1991


Epoch 48/100: 100%|██████████| 4/4 [00:06<00:00,  1.54s/it]


Epoch 48/100, Loss: 2.1237


Epoch 49/100: 100%|██████████| 4/4 [00:06<00:00,  1.52s/it]


Epoch 49/100, Loss: 2.0677


Epoch 50/100: 100%|██████████| 4/4 [00:06<00:00,  1.54s/it]


Epoch 50/100, Loss: 1.9431


Epoch 51/100: 100%|██████████| 4/4 [00:06<00:00,  1.54s/it]


Epoch 51/100, Loss: 1.8966


Epoch 52/100: 100%|██████████| 4/4 [00:05<00:00,  1.50s/it]


Epoch 52/100, Loss: 1.6863


Epoch 53/100: 100%|██████████| 4/4 [00:06<00:00,  1.52s/it]


Epoch 53/100, Loss: 1.6497


Epoch 54/100: 100%|██████████| 4/4 [00:05<00:00,  1.49s/it]


Epoch 54/100, Loss: 1.8207


Epoch 55/100: 100%|██████████| 4/4 [00:06<00:00,  1.58s/it]


Epoch 55/100, Loss: 1.6217


Epoch 56/100: 100%|██████████| 4/4 [00:06<00:00,  1.54s/it]


Epoch 56/100, Loss: 1.5697


Epoch 57/100: 100%|██████████| 4/4 [00:05<00:00,  1.49s/it]


Epoch 57/100, Loss: 1.7158


Epoch 58/100: 100%|██████████| 4/4 [00:06<00:00,  1.51s/it]


Epoch 58/100, Loss: 1.6341


Epoch 59/100: 100%|██████████| 4/4 [00:06<00:00,  1.55s/it]


Epoch 59/100, Loss: 1.4677


Epoch 60/100: 100%|██████████| 4/4 [00:06<00:00,  1.65s/it]


Epoch 60/100, Loss: 1.5464


Epoch 61/100: 100%|██████████| 4/4 [00:06<00:00,  1.56s/it]


Epoch 61/100, Loss: 1.3593


Epoch 62/100: 100%|██████████| 4/4 [00:06<00:00,  1.53s/it]


Epoch 62/100, Loss: 1.2739


Epoch 63/100: 100%|██████████| 4/4 [00:06<00:00,  1.52s/it]


Epoch 63/100, Loss: 1.3253


Epoch 64/100: 100%|██████████| 4/4 [00:06<00:00,  1.56s/it]


Epoch 64/100, Loss: 1.1955


Epoch 65/100: 100%|██████████| 4/4 [00:06<00:00,  1.67s/it]


Epoch 65/100, Loss: 1.2689


Epoch 66/100: 100%|██████████| 4/4 [00:06<00:00,  1.52s/it]


Epoch 66/100, Loss: 1.1440


Epoch 67/100: 100%|██████████| 4/4 [00:06<00:00,  1.52s/it]


Epoch 67/100, Loss: 1.1868


Epoch 68/100: 100%|██████████| 4/4 [00:06<00:00,  1.52s/it]


Epoch 68/100, Loss: 1.1577


Epoch 69/100: 100%|██████████| 4/4 [00:06<00:00,  1.55s/it]


Epoch 69/100, Loss: 1.1234


Epoch 70/100: 100%|██████████| 4/4 [00:05<00:00,  1.50s/it]


Epoch 70/100, Loss: 1.1757


Epoch 71/100: 100%|██████████| 4/4 [00:06<00:00,  1.53s/it]


Epoch 71/100, Loss: 1.1408


Epoch 72/100: 100%|██████████| 4/4 [00:06<00:00,  1.55s/it]


Epoch 72/100, Loss: 1.0321


Epoch 73/100: 100%|██████████| 4/4 [00:06<00:00,  1.54s/it]


Epoch 73/100, Loss: 1.0755


Epoch 74/100: 100%|██████████| 4/4 [00:06<00:00,  1.62s/it]


Epoch 74/100, Loss: 1.1373


Epoch 75/100: 100%|██████████| 4/4 [00:06<00:00,  1.56s/it]


Epoch 75/100, Loss: 0.9654


Epoch 76/100: 100%|██████████| 4/4 [00:06<00:00,  1.60s/it]


Epoch 76/100, Loss: 1.0111


Epoch 77/100: 100%|██████████| 4/4 [00:06<00:00,  1.53s/it]


Epoch 77/100, Loss: 1.1721


Epoch 78/100: 100%|██████████| 4/4 [00:06<00:00,  1.51s/it]


Epoch 78/100, Loss: 0.9080


Epoch 79/100: 100%|██████████| 4/4 [00:06<00:00,  1.54s/it]


Epoch 79/100, Loss: 1.0149


Epoch 80/100: 100%|██████████| 4/4 [00:06<00:00,  1.57s/it]


Epoch 80/100, Loss: 0.9150


Epoch 81/100: 100%|██████████| 4/4 [00:06<00:00,  1.57s/it]


Epoch 81/100, Loss: 0.9833


Epoch 82/100: 100%|██████████| 4/4 [00:06<00:00,  1.55s/it]


Epoch 82/100, Loss: 0.9905


Epoch 83/100: 100%|██████████| 4/4 [00:06<00:00,  1.55s/it]


Epoch 83/100, Loss: 0.9349


Epoch 84/100: 100%|██████████| 4/4 [00:05<00:00,  1.49s/it]


Epoch 84/100, Loss: 0.8925


Epoch 85/100: 100%|██████████| 4/4 [00:06<00:00,  1.54s/it]


Epoch 85/100, Loss: 0.9955


Epoch 86/100: 100%|██████████| 4/4 [00:06<00:00,  1.64s/it]


Epoch 86/100, Loss: 0.9639


Epoch 87/100: 100%|██████████| 4/4 [00:06<00:00,  1.56s/it]


Epoch 87/100, Loss: 0.9983


Epoch 88/100: 100%|██████████| 4/4 [00:06<00:00,  1.55s/it]


Epoch 88/100, Loss: 0.9054


Epoch 89/100: 100%|██████████| 4/4 [00:06<00:00,  1.60s/it]


Epoch 89/100, Loss: 0.8903


Epoch 90/100: 100%|██████████| 4/4 [00:06<00:00,  1.57s/it]


Epoch 90/100, Loss: 0.8624


Epoch 91/100: 100%|██████████| 4/4 [00:06<00:00,  1.74s/it]


Epoch 91/100, Loss: 0.9326


Epoch 92/100: 100%|██████████| 4/4 [00:06<00:00,  1.56s/it]


Epoch 92/100, Loss: 0.9190


Epoch 93/100: 100%|██████████| 4/4 [00:06<00:00,  1.62s/it]


Epoch 93/100, Loss: 0.9830


Epoch 94/100: 100%|██████████| 4/4 [00:06<00:00,  1.62s/it]


Epoch 94/100, Loss: 0.7850


Epoch 95/100: 100%|██████████| 4/4 [00:06<00:00,  1.54s/it]


Epoch 95/100, Loss: 0.9552


Epoch 96/100: 100%|██████████| 4/4 [00:06<00:00,  1.67s/it]


Epoch 96/100, Loss: 0.9477


Epoch 97/100: 100%|██████████| 4/4 [00:06<00:00,  1.60s/it]


Epoch 97/100, Loss: 0.7988


Epoch 98/100: 100%|██████████| 4/4 [00:06<00:00,  1.55s/it]


Epoch 98/100, Loss: 0.9270


Epoch 99/100: 100%|██████████| 4/4 [00:06<00:00,  1.58s/it]


Epoch 99/100, Loss: 0.8889


Epoch 100/100: 100%|██████████| 4/4 [00:06<00:00,  1.57s/it]


Epoch 100/100, Loss: 0.9276


Embedding: 100%|██████████| 315/315 [00:05<00:00, 62.35it/s]

Эмбеддинги сохранены

Топ 5 похожих с: processed_1_79d4e30a-IMG_20221011_103356.jpg

processed_1_a8cbf437-IMG_20221011_103352.jpg (sim: 0.9157)
processed_51_9d315f77-IMG_20221011_123545.jpg (sim: 0.8278)
processed_57_83ee8236-IMG_20221011_125037.jpg (sim: 0.8176)
processed_36_d557aca7-IMG_20221011_120343.jpg (sim: 0.7204)
processed_30_d9137b1a-IMG_20221011_115024.jpg (sim: 0.6897)





In [2]:
find_similar('/kaggle/input/cow-f-crop/cow_f_crop/processed_1_79d4e30a-IMG_20221011_103356.jpg', embeddings, top_k=1)


Топ 1 похожих с: processed_1_79d4e30a-IMG_20221011_103356.jpg

processed_1_a8cbf437-IMG_20221011_103352.jpg (sim: 0.9157)


In [5]:
find_similar('/kaggle/input/cow-f-crop/cow_f_crop/processed_23_69024848-IMG_20221011_112623.jpg', embeddings, top_k=1)


Топ 1 похожих с: processed_23_69024848-IMG_20221011_112623.jpg

processed_23_8dad22cb-IMG_20221011_112624.jpg (sim: 0.9042)


In [8]:
find_similar('/kaggle/input/cow-f-crop/cow_f_crop/processed_27_a94a8a5e-IMG_20221011_113859.jpg', embeddings, top_k=1)


Топ 1 похожих с: processed_27_a94a8a5e-IMG_20221011_113859.jpg

processed_27_e0939cc5-IMG_20221011_113858.jpg (sim: 0.9660)


In [9]:
torch.save({
    'model_state_dict': model.state_dict(),
    'model_architecture': 'SimCLR',
    'img_size': IMG_SIZE,
    'backbone': 'resnet50'
}, 'simclr_cow_model.pth')
print("Модель сохранена: simclr_cow_model.pth")

Модель сохранена: simclr_cow_model.pth
