In [None]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import Dinov2PreTrainedModel, Dinov2Model
import os
from PIL import Image
import json
import numpy as np
import albumentations as A
from torch.optim.lr_scheduler import CosineAnnealingLR
import warnings
import torch.nn.functional as F

# Use the first CUDA device when one is available, otherwise the CPU.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
# Ignore UserWarnings matching the message below.
warnings.filterwarnings(
    action="ignore",
    message="Palette images with Transparency expressed in bytes should be converted to RGBA images",
    category=UserWarning,
)

In [2]:
# Data processing
# Per-channel mean and standard deviation, written on the 0-255 scale and
# divided by 255 before being handed to A.Normalize.
ADE_MEAN = np.array([123.675, 116.280, 103.530]) / 255
ADE_STD = np.array([58.395, 57.120, 57.375]) / 255

# Training pipeline: resize to 224x224, random horizontal flip, mild colour
# jitter, then normalisation.
train_transform = A.Compose(
    [
        A.Resize(height=224, width=224),
        A.HorizontalFlip(p=0.5),
        A.ColorJitter(
            brightness=0.1,
            contrast=0.1,
            saturation=0.1,
            hue=0.1,
        ),
        A.Normalize(mean=ADE_MEAN, std=ADE_STD),
    ]
)

# Validation pipeline: resize and normalise only, no random augmentation.
val_transform = A.Compose(
    [
        A.Resize(height=224, width=224),
        A.Normalize(mean=ADE_MEAN, std=ADE_STD),
    ]
)

In [3]:
# Load sub-elements from txt file
def load_sub_elements(txt_file):
    """Parse a sub-element listing into ``{"filename", "category"}`` records.

    Every line is stripped and split on the literal separator ``'.png '``.
    The text before the first separator becomes ``filename`` (so the
    ``.png`` extension is not part of it) and the text between the first and
    the second separator, or the end of the line, becomes ``category``.
    Blank lines and lines without the separator are skipped.

    Args:
        txt_file: Path of the text file to read.

    Returns:
        A list of dicts, one per accepted line, in file order.
    """
    records = []
    with open(txt_file, 'r') as handle:
        for raw_line in handle:
            pieces = raw_line.strip().split('.png ')
            # A blank line or a line without the separator yields one piece.
            if len(pieces) < 2:
                continue
            records.append({"filename": pieces[0], "category": pieces[1]})
    return records

# Load image database from JSON file
def load_image_database(json_file):
    """Read the image database JSON into ``{"filename", "categories"}`` records.

    The JSON document is iterated entry by entry. For each entry with at
    least two elements, element 1 is taken as the filename and the next
    ``len(entry) // 2 - 1`` elements (starting at index 2) as its categories.
    Shorter entries are skipped.
    NOTE(review): the meaning of element 0 and of the trailing elements is
    not visible here -- confirm against the dataset format.

    Args:
        json_file: Path of the JSON file to read.

    Returns:
        A list of dicts, one per accepted entry, in file order.
    """
    with open(json_file, 'r') as handle:
        entries = json.load(handle)

    database = []
    for entry in entries:
        if len(entry) < 2:
            continue
        category_count = len(entry) // 2 - 1
        database.append({
            "filename": entry[1],
            "categories": entry[2:2 + category_count],
        })
    return database

In [4]:
# Triplet dataset building
class TripletDataset(Dataset):
    """Dataset producing (anchor, positive, negative) image triplets.

    The anchor is a sub-element image. The positive is a database image whose
    ``categories`` contain the anchor's id, and the negative is a database
    image whose ``categories`` do not. The anchor id is derived from the
    sub-element filename: last path component, text before the first comma,
    stripped.

    NOTE(review): lookups compare that string id against the ``categories``
    entries of ``image_db``, so those entries must be equal strings --
    confirm against the JSON data.

    Args:
        sub_elements: List of ``{"filename", "category"}`` dicts. ``filename``
            is opened as ``filename + '.png'``.
        image_db: List of ``{"filename", "categories"}`` dicts.
        sub_elements_dir: Stored on the instance but not read by this class.
        images_dir: Directory that contains ``<filename>.png`` for every
            ``image_db`` entry.
        transform: Optional callable invoked as ``transform(image=array)`` and
            returning a mapping with an ``"image"`` entry. When ``None`` the
            raw arrays are used.
    """

    def __init__(self, sub_elements, image_db, sub_elements_dir, images_dir, transform=None):
        self.sub_elements = sub_elements
        self.image_db = image_db
        self.sub_elements_dir = sub_elements_dir
        self.images_dir = images_dir
        self.transform = transform

        # category -> database entries containing it (positives), and
        # category -> indices of those entries, both built in a single pass.
        self.category_to_images = {}
        category_to_indices = {}
        for index, img in enumerate(self.image_db):
            for cat in img["categories"]:
                self.category_to_images.setdefault(cat, []).append(img)
                category_to_indices.setdefault(cat, set()).add(index)

        # category -> indices of database entries NOT containing it (negatives).
        all_image_indices = set(range(len(self.image_db)))
        self.category_exclude_images = {
            cat: sorted(all_image_indices - indices)
            for cat, indices in category_to_indices.items()
        }

    def __len__(self):
        """Return the number of sub-elements (one triplet per sub-element)."""
        return len(self.sub_elements)

    @staticmethod
    def _load_rgb_array(path):
        """Load ``path`` as an RGB uint8 array; fall back to a black 224x224 image on failure."""
        try:
            with Image.open(path) as img:
                return np.array(img.convert('RGB'))
        except Exception as e:
            # Best effort: report the problem and keep the epoch running.
            print(f"Error loading {path}: {e}")
            return np.zeros((224, 224, 3), dtype=np.uint8)

    @staticmethod
    def _pick(candidates):
        """Return one uniformly random element of a non-empty sequence."""
        return candidates[np.random.randint(0, len(candidates))]

    def _to_chw_tensor(self, image):
        """Apply the optional transform and convert an HWC array to a float CHW tensor."""
        if self.transform is not None:
            image = self.transform(image=image)["image"]
        return torch.tensor(image).permute(2, 0, 1).float()

    def __getitem__(self, idx):
        """Build the triplet for sub-element ``idx``.

        Returns:
            A dict with the three image tensors (``anchor``, ``positive``,
            ``negative``), the three file paths, the anchor's ``category``
            and the category lists of the chosen positive and negative.
        """
        sub_element = self.sub_elements[idx]
        sub_element_filename = sub_element["filename"]
        sub_element_category = sub_element["category"]
        sub_element_id = sub_element_filename.split('/')[-1].split(',')[0].strip()

        sub_element_path = sub_element_filename + '.png'
        sub_element_image = self._load_rgb_array(sub_element_path)

        # Positive: any database image that contains the anchor's id.
        positive_images = self.category_to_images.get(sub_element_id, [])
        if not positive_images:
            print(f"Warning: No positive images found for category {sub_element_id} in image_db")
            positive_img = self._pick(self.image_db)
        else:
            positive_img = self._pick(positive_images)
            assert sub_element_id in positive_img["categories"], "Selected positive image does not contain target category"

        positive_path = os.path.join(self.images_dir, f"{positive_img['filename']}.png")
        positive_image = self._load_rgb_array(positive_path)

        # Negative: any database image that does not contain the anchor's id.
        negative_images = self.category_exclude_images.get(sub_element_id, [])
        if not negative_images:
            print(f"Warning: No negative images found for category {sub_element_id} in image_db")
            negative_img = self._pick(self.image_db)
        else:
            negative_img = self.image_db[self._pick(negative_images)]
            assert sub_element_id not in negative_img["categories"], "Selected negative image contains target category"

        negative_path = os.path.join(self.images_dir, f"{negative_img['filename']}.png")
        negative_image = self._load_rgb_array(negative_path)

        # The tensors are defined whether or not a transform was supplied.
        anchor = self._to_chw_tensor(sub_element_image)
        positive = self._to_chw_tensor(positive_image)
        negative = self._to_chw_tensor(negative_image)

        return {
            "anchor": anchor,
            "positive": positive,
            "negative": negative,
            "anchor_path": sub_element_path,
            "positive_path": positive_path,
            "negative_path": negative_path,
            "category": sub_element_category,
            "positive_categories": positive_img["categories"],
            "negative_categories": negative_img["categories"]
        }

In [5]:
# Dinov2
class Dinov2FeatureExtractor(Dinov2PreTrainedModel):
    """DINOv2 backbone followed by a two-layer MLP projecting token 0 to 256 dims.

    On construction the whole backbone is frozen, then the last two encoder
    blocks and the backbone's final layernorm are made trainable again.

    NOTE(review): the projection head is frozen as well, although the
    fallback message in ``_unfreeze_dinov2_layers`` says "Only train the
    projection head" -- confirm this is intended.
    """

    def __init__(self, config):
        super().__init__(config)
        width = config.hidden_size
        self.dinov2 = Dinov2Model(config)
        self.projection_head = nn.Sequential(
            nn.Linear(width, width),
            nn.GELU(),
            nn.Linear(width, 256),
        )

        # Freeze everything in the backbone, then selectively re-enable.
        self.dinov2.requires_grad_(False)
        self._unfreeze_dinov2_layers(2)

        # The head's parameters are excluded from gradient updates.
        self.projection_head.requires_grad_(False)

    def _unfreeze_dinov2_layers(self, unfreeze_layers):
        """Enable gradients for the last ``unfreeze_layers`` encoder blocks and the final layernorm.

        Any exception raised while doing so is printed, not propagated.
        """
        try:
            blocks = self.dinov2.encoder.layer
            total_blocks = len(blocks)
            first_trainable = max(0, total_blocks - unfreeze_layers)

            print(f"Unfreeze the last {unfreeze_layers} Transformer blocks ({first_trainable}-{total_blocks-1})")

            for block_index in range(first_trainable, total_blocks):
                blocks[block_index].requires_grad_(True)
                print(f"Unfreeze block {block_index}")

            self.dinov2.layernorm.requires_grad_(True)
            print("Unfreeze layernorm layer")

        except Exception as e:
            print(f"Error occurred during unfreezing: {e}")
            print("Only train the projection head")

    def forward(self, pixel_values, output_hidden_states=False, output_attentions=False, return_attentions=False):
        """Encode ``pixel_values`` and project token 0 of the last hidden state.

        Returns:
            With ``return_attentions`` true, the tuple ``(features,
            last_hidden_state, hidden_states, attentions)``; otherwise a dict
            holding the same four values under those names.
        """
        backbone_out = self.dinov2(
            pixel_values,
            output_hidden_states=output_hidden_states,
            output_attentions=output_attentions,
        )
        tokens = backbone_out.last_hidden_state

        # Token 0 (CLS) -> [batch_size, hidden_size] -> [batch_size, 256]
        features = self.projection_head(tokens[:, 0, :])

        if return_attentions:
            return features, tokens, backbone_out.hidden_states, backbone_out.attentions
        return {
            'features': features,
            'last_hidden_state': tokens,
            'hidden_states': backbone_out.hidden_states,
            'attentions': backbone_out.attentions
        }

In [6]:
# Query-Guided Attention Module
class QueryGuidedAttention(nn.Module):
    """Multi-head cross-attention where query tokens attend over target tokens."""

    def __init__(self, hidden_size=768, num_heads=8, dropout=0.1):
        super(QueryGuidedAttention, self).__init__()
        attention_kwargs = dict(
            embed_dim=hidden_size,
            num_heads=num_heads,
            dropout=dropout,
            batch_first=True,
        )
        self.multihead_attn = nn.MultiheadAttention(**attention_kwargs)

    def forward(self, query_global, target_spatial):
        """Attend from the query to the target tokens.

        Args:
            query_global: [batch_size, 1, hidden_size] sub-element features.
            target_spatial: [batch_size, seq_len, hidden_size] image features,
                used as both key and value.

        Returns:
            The ``(context, attn_weights)`` pair produced by the attention layer.
        """
        return self.multihead_attn(
            query_global,
            target_spatial,
            target_spatial,
            need_weights=True,
        )

In [7]:
class L2Norm(nn.Module):
    """Module wrapper that L2-normalises its input along dimension 1."""

    def forward(self, x):
        """Return ``x`` scaled to unit L2 norm along dimension 1."""
        unit_vectors = F.normalize(x, dim=1, p=2)
        return unit_vectors

class AttentionFeatureExtractor(Dinov2PreTrainedModel):
    """DINOv2 backbone with query-guided cross-attention and 256-d projection heads.

    The query image is represented by token 0 of the backbone output and the
    target image by the remaining tokens. The query attends over the target
    tokens, and both the query token and the attended context are projected
    to L2-normalised 256-d vectors.

    On construction the backbone is frozen except for its last two encoder
    blocks and final layernorm; the projection heads and the cross-attention
    module are trainable.
    """

    def __init__(self, config):
        super().__init__(config)
        width = config.hidden_size
        self.dinov2 = Dinov2Model(config)
        self.cross_attention = QueryGuidedAttention(hidden_size=width, num_heads=8, dropout=0.1)

        self.query_projection = nn.Sequential(nn.Linear(width, 256), nn.GELU(), L2Norm())
        self.target_projection = nn.Sequential(nn.Linear(width, 256), nn.GELU(), L2Norm())
        self.attention_align = nn.Sequential(nn.Linear(width, 256), L2Norm())

        # Freeze the backbone, then re-enable its last two blocks and layernorm.
        self.dinov2.requires_grad_(False)
        self._unfreeze_dinov2_layers(2)

        # Everything outside the backbone is explicitly marked trainable.
        trainable_modules = (
            self.query_projection,
            self.target_projection,
            self.attention_align,
            self.cross_attention,
        )
        for module in trainable_modules:
            module.requires_grad_(True)

    def _unfreeze_dinov2_layers(self, unfreeze_layers):
        """Enable gradients for the last ``unfreeze_layers`` encoder blocks and the final layernorm.

        Any exception raised while doing so is printed, not propagated.
        """
        try:
            blocks = self.dinov2.encoder.layer
            total_blocks = len(blocks)
            first_trainable = max(0, total_blocks - unfreeze_layers)

            print(f"Unfreeze the last {unfreeze_layers} Transformer blocks ({first_trainable}-{total_blocks-1})")

            for block_index in range(first_trainable, total_blocks):
                blocks[block_index].requires_grad_(True)
                print(f"Unfreeze block {block_index}")

            self.dinov2.layernorm.requires_grad_(True)
            print("Unfreeze layernorm layer")

        except Exception as e:
            print(f"Error occurred during unfreezing: {e}")
            print("Only train the projection head")

    def forward(self, query_images, target_images, output_hidden_states=False, output_attentions=False, is_train=True, return_attentions=False):
        """Embed a query image, optionally guided through a target image.

        Args:
            query_images: Batch of query (sub-element) images.
            target_images: Batch of target images, or ``None``.
            output_hidden_states, output_attentions: Forwarded to the backbone
                call for the query images only.
            is_train: When false the target branch is skipped.
            return_attentions: Also return the cross-attention weights.

        Returns:
            If ``is_train`` is true and ``target_images`` is given: the tuple
            ``(query_feat, target_feat, align_feat)``, with ``attn_weights``
            appended when ``return_attentions`` is true. Otherwise a single
            tensor holding the projected query features (not a tuple).
        """
        query_outputs = self.dinov2(
            query_images,
            output_hidden_states=output_hidden_states,
            output_attentions=output_attentions,
        )
        # Keep the token axis ([batch, 1, hidden]) so it can serve as the attention query.
        query_global = query_outputs.last_hidden_state[:, :1, :]

        # Query-only path: no target to attend over.
        if not is_train or target_images is None:
            return self.query_projection(query_global.squeeze(1))

        # The target image is represented by every token after token 0.
        target_spatial = self.dinov2(target_images).last_hidden_state[:, 1:, :]

        # query: query_global, key/value: target_spatial
        context, attn_weights = self.cross_attention(query_global, target_spatial)

        query_vector = query_global.squeeze(1)
        query_feat = self.query_projection(query_vector)
        target_feat = self.target_projection(context.squeeze(1))
        align_feat = self.attention_align(query_vector)

        if return_attentions:
            return query_feat, target_feat, align_feat, attn_weights
        return query_feat, target_feat, align_feat

In [8]:
class TripletLoss(nn.Module):
    """Margin-based triplet loss computed on cosine similarity.

    For each triplet the loss is ``max(0, margin - sim(a, p) + sim(a, n))``,
    and the batch mean is returned.
    """

    def __init__(self, margin=0.2):
        super().__init__()
        self.distance_fn = nn.CosineSimilarity(dim=1)
        self.margin = margin

    def forward(self, anchor, positive, negative):
        """Return the mean hinge loss over the batch as a scalar tensor."""
        sim_to_positive = self.distance_fn(anchor, positive)
        sim_to_negative = self.distance_fn(anchor, negative)

        hinge = (self.margin - sim_to_positive + sim_to_negative).clamp(min=0.0)
        return hinge.mean()

In [None]:
def _collate_triplets(batch):
    """Stack the per-sample image tensors and gather paths/categories into lists.

    Defined at module level (not nested in ``main``) so DataLoader worker
    processes can pickle it under the ``spawn`` start method.
    """
    return {
        "anchor": torch.stack([item["anchor"] for item in batch]),
        "positive": torch.stack([item["positive"] for item in batch]),
        "negative": torch.stack([item["negative"] for item in batch]),
        "anchor_paths": [item["anchor_path"] for item in batch],
        "positive_paths": [item["positive_path"] for item in batch],
        "negative_paths": [item["negative_path"] for item in batch],
        "categories": [item["category"] for item in batch]
    }


def _triplet_loss_for_batch(feature_extractor, batch, triplet_criterion, device):
    """Run the three forward passes for one batch and return the triplet loss."""
    anchor = batch["anchor"].to(device)
    positive = batch["positive"].to(device)
    negative = batch["negative"].to(device)

    # With target_images=None the model returns the projected query features
    # as one [batch, 256] tensor. It must not be indexed with [0]: that would
    # keep only the first sample and broadcast it against the whole batch.
    anchor_features = feature_extractor(query_images=anchor, target_images=None)

    # Positive feature extraction (using anchor to query positive image)
    _, positive_features, _ = feature_extractor(query_images=anchor, target_images=positive)

    # Negative feature extraction (using anchor to query negative image)
    _, negative_features, _ = feature_extractor(query_images=anchor, target_images=negative)

    return triplet_criterion(anchor_features, positive_features, negative_features)


def _train_one_epoch(feature_extractor, dataloader, triplet_criterion, feature_optimizer, device):
    """Train for one pass over ``dataloader`` and return the mean batch loss."""
    feature_extractor.train()
    total_triplet_loss = 0.0

    for batch in dataloader:
        feature_optimizer.zero_grad()
        triplet_loss = _triplet_loss_for_batch(feature_extractor, batch, triplet_criterion, device)

        # Backpropagation and optimization
        triplet_loss.backward()
        feature_optimizer.step()

        total_triplet_loss += triplet_loss.item()

    # max(..., 1) keeps an empty loader from raising ZeroDivisionError.
    return total_triplet_loss / max(len(dataloader), 1)


def _evaluate(feature_extractor, dataloader, triplet_criterion, device):
    """Return the mean batch loss over ``dataloader`` without updating weights."""
    feature_extractor.eval()
    total_triplet_loss = 0.0

    with torch.no_grad():
        for batch in dataloader:
            triplet_loss = _triplet_loss_for_batch(feature_extractor, batch, triplet_criterion, device)
            total_triplet_loss += triplet_loss.item()

    return total_triplet_loss / max(len(dataloader), 1)


def main():
    """Fine-tune ``AttentionFeatureExtractor`` with a triplet loss.

    Loads the sub-element lists and image databases, builds the train and
    validation loaders, trains for a fixed number of epochs and writes a
    checkpoint to ``fine_tuned/`` every fifth epoch.
    """
    # Data path
    sub_elements_txt = '/Dataset/train_sub.txt'  # Path and category of sub elements training dataset
    val_sub_txt = '/Dataset/val_sub.txt' # Path and category of sub elements validating dataset
    image_db_json = '/Dataset/Train_split.json'  # Image database for training
    val_image_db = '/Dataset/Validation_split.json' # Image database for validating
    sub_elements_dir = '/Dataset/element img'  # Sub-element images directory
    images_dir = '/Dataset/SimulatedPrintedFabrics-17k/train/img/images'  # Images for training
    val_dir = '/Dataset/SimulatedPrintedFabrics-17k/validation/img/images' # Images for validating
    checkpoint_dir = 'fine_tuned'  # Checkpoints are written here

    # Create the checkpoint directory up front so a missing folder cannot
    # abort the run at the first save.
    os.makedirs(checkpoint_dir, exist_ok=True)

    # Load data
    print("Loading sub-elements...")
    sub_elements = load_sub_elements(sub_elements_txt)
    val_sub = load_sub_elements(val_sub_txt)
    print(f"Loaded {len(sub_elements)} sub-elements")

    print("Loading image database...")
    image_db = load_image_database(image_db_json)
    val_db = load_image_database(val_image_db)
    print(f"Loaded {len(image_db)} images")

    # Dataset and DataLoader
    train_size = len(sub_elements)
    val_size = len(val_sub)
    print(train_size, val_size)

    train_dataset = TripletDataset(sub_elements, image_db, sub_elements_dir, images_dir, train_transform)
    val_dataset = TripletDataset(val_sub, val_db, sub_elements_dir, val_dir, val_transform)

    train_dataloader = DataLoader(train_dataset, batch_size=6, shuffle=True, collate_fn=_collate_triplets, num_workers=4)
    val_dataloader = DataLoader(val_dataset, batch_size=6, shuffle=False, collate_fn=_collate_triplets, num_workers=4)

    # Model initialization
    model = AttentionFeatureExtractor.from_pretrained("/Weight_Path/dinov2-pytorch-base-v1")
    model.to(device)

    # Definition of loss functions
    triplet_criterion = TripletLoss(margin=0.2)

    # Per-group learning rates: trainable backbone layers at lr/10, the
    # projection heads at lr and the cross-attention at 2*lr.
    lr = 1e-5
    params = [
        {'params': [p for n, p in model.named_parameters() if 'dinov2' in n and p.requires_grad], 'lr': lr/10},
        {'params': [p for n, p in model.named_parameters() if 'query_projection' in n], 'lr': lr},
        {'params': [p for n, p in model.named_parameters() if 'target_projection' in n], 'lr': lr},
        {'params': [p for n, p in model.named_parameters() if 'attention_align' in n], 'lr': lr},
        {'params': [p for n, p in model.named_parameters() if 'cross_attention' in n], 'lr': lr*2}
    ]
    feature_optimizer = torch.optim.AdamW(params, lr=lr, weight_decay=5e-4)
    # NOTE(review): T_max=80 is shorter than the 100 epochs below, so the
    # cosine schedule bottoms out at epoch 80 and the learning rate rises
    # again afterwards -- confirm this is intended.
    feature_scheduler = CosineAnnealingLR(feature_optimizer, T_max=80, eta_min=lr/100)

    # Training loop
    num_epochs = 100
    print('Starting training...')
    for epoch in range(num_epochs):

        train_triplet_loss = _train_one_epoch(
            model, train_dataloader,
            triplet_criterion,
            feature_optimizer, device
        )

        val_triplet_loss = _evaluate(
            model, val_dataloader,
            triplet_criterion, device
        )

        feature_scheduler.step()

        print(f"Epoch {epoch+1}/{num_epochs}")
        print(f"Train Loss: Triplet={train_triplet_loss:.4f}")
        print(f"Val Loss:   Triplet={val_triplet_loss:.4f}")

        if (epoch + 1) % 5 == 0:
            # 'epoch' is the zero-based index; the file name uses epoch + 1.
            torch.save({
                'epoch': epoch,
                'feature_extractor_state_dict': model.state_dict(),
                'feature_optimizer_state_dict': feature_optimizer.state_dict(),
                'feature_scheduler_state_dict': feature_scheduler.state_dict(),
                'train_triplet_loss': train_triplet_loss,
                'val_triplet_loss': val_triplet_loss,
            }, os.path.join(checkpoint_dir, f'dinov2_query_epoch_{epoch+1}.pth'))

    
# Entry point: start training when this module is executed as __main__.
if __name__ == "__main__":
    main()