In [None]:
# Siamese model, moved to the target device before the optimizer collects its parameters.
model = SiameseNetwork().to(device)

# Contrastive objective with unit margin and equal weighting of both pair types.
criterion = WeightedContrastiveLoss(margin=1.0, pos_weight=1.0, neg_weight=1.0)

optimizer = optim.AdamW(
    model.parameters(),
    lr=learning_rate,
    weight_decay=1e-4,
)

# One-cycle schedule sized for num_epochs * len(train_loader) steps, i.e. it is
# configured for one scheduler step per training batch.
# NOTE(review): per the PyTorch docs, OneCycleLR derives the starting lr from
# max_lr / div_factor, so `learning_rate` passed to AdamW above is presumably
# overridden once the scheduler is attached — confirm this is intended.
scheduler = OneCycleLR(
    optimizer,
    max_lr=3e-4,
    epochs=num_epochs,
    steps_per_epoch=len(train_loader),
    pct_start=0.3,            # fraction of the cycle spent increasing the lr
    anneal_strategy='cos',
    div_factor=25.0,          # initial lr = max_lr / div_factor
    final_div_factor=1e4,     # minimum lr = initial lr / final_div_factor
    three_phase=False,
)

In [None]:
class WeightedContrastiveLoss(nn.Module):
    """Contrastive loss over embedding pairs with separate pair-type weights.

    Pairs labelled ``1`` are penalised by their squared distance; pairs
    labelled ``-1`` are penalised by the squared amount by which their
    distance falls short of ``margin``.  Elements with any other label value
    add zero to the sum but are still counted by the final mean.

    Args:
        margin: Distance beyond which ``-1`` pairs stop contributing.
        pos_weight: Multiplier applied to the ``1``-pair term.
        neg_weight: Multiplier applied to the ``-1``-pair term.
    """

    def __init__(self, margin=1.0, pos_weight=1.0, neg_weight=1.0):
        super().__init__()
        self.margin = margin
        self.pos_weight = pos_weight
        self.neg_weight = neg_weight

    def forward(self, output1, output2, label):
        """Return the mean weighted loss for a batch of embedding pairs.

        Args:
            output1: Embeddings of the first element of each pair.
            output2: Embeddings of the second element of each pair.
            label: Tensor holding ``1`` or ``-1`` per pair.

        Returns:
            A scalar tensor: the mean of the per-element weighted loss.
        """
        dist = F.pairwise_distance(output1, output2)

        # Float masks selecting which term applies to each pair.
        is_similar = label.eq(1).float()
        is_dissimilar = label.eq(-1).float()

        # Pull matching pairs together; push non-matching pairs out to the margin.
        pull = is_similar * dist.pow(2)
        push = is_dissimilar * (self.margin - dist).clamp(min=0.0).pow(2)

        return (self.pos_weight * pull + self.neg_weight * push).mean()
        
class SEBlock(nn.Module):
    """Squeeze-and-excitation style channel gate for 2-D feature tensors.

    The input is passed through a two-layer bottleneck (``channels`` ->
    ``channels // reduction`` -> ``channels``) with ReLU then sigmoid, and
    the resulting per-channel gate is multiplied element-wise into the input.

    Args:
        channels: Size of the second dimension of the input tensor.
        reduction: Divisor used to size the bottleneck layer.
    """

    def __init__(self, channels, reduction=16):
        super().__init__()
        hidden = channels // reduction
        self.fc1 = nn.Linear(channels, hidden)
        self.fc2 = nn.Linear(hidden, channels)

    def forward(self, x):
        """Return ``x`` scaled by its learned per-channel gate.

        Args:
            x: Tensor with exactly two dimensions, ``(batch, channels)``.

        Returns:
            Tensor of the same shape as ``x``.

        Raises:
            ValueError: If ``x`` does not have exactly two dimensions.
        """
        if x.dim() != 2:
            raise ValueError(
                "SEBlock expects a 2-D (batch, channels) tensor, "
                f"got shape {tuple(x.shape)}"
            )
        # The input has no spatial axis left to squeeze, so it feeds the
        # excitation layers directly (average-pooling a length-1 axis is the
        # identity and was dropped).
        gate = torch.sigmoid(self.fc2(F.relu(self.fc1(x))))
        return x * gate

class SiameseNetwork(nn.Module):
    """Twin-branch embedding network on top of DenseNet-121 features.

    Each input is a two-element ``(img, lbp)`` pair that is concatenated
    along dim 1, run through the shared feature extractor, globally
    average-pooled, gated by an SE block and projected to a 128-dimensional
    embedding.  Both inputs share the same weights.
    """

    def __init__(self):
        super().__init__()
        densenet = models.densenet121(pretrained=True)

        # Swap the stem for a 2-input-channel convolution.
        # NOTE(review): this replacement layer is freshly initialised, so the
        # pretrained first-layer filters are not reused — confirm intended.
        densenet.features.conv0 = nn.Conv2d(
            2, 64, kernel_size=7, stride=2, padding=3, bias=False
        )
        self.feature_extractor = densenet.features

        self.se_block = SEBlock(1024, reduction=16)

        # Projection head: 1024 -> 256 -> 128, ending on a BatchNorm layer.
        head_layers = [
            nn.Linear(1024, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(256, 128),
            nn.BatchNorm1d(128),
        ]
        self.embedding = nn.Sequential(*head_layers)

    def forward_once(self, x):
        """Compute the embedding for one already-packed input batch."""
        feature_map = self.feature_extractor(x)
        # NOTE(review): torchvision's own DenseNet forward applies a ReLU
        # between `features` and the pooling step; this path does not —
        # confirm the omission is deliberate.
        pooled = F.adaptive_avg_pool2d(feature_map, (1, 1)).flatten(1)
        return self.embedding(self.se_block(pooled))

    @staticmethod
    def _pack(pair):
        """Concatenate an ``(img, lbp)`` pair along dim 1."""
        img, lbp = pair
        return torch.cat([img, lbp], dim=1)

    def forward(self, input1, input2):
        """Return the embeddings of both inputs as ``(out1, out2)``.

        Args:
            input1: Two-element ``(img, lbp)`` sequence for the first branch.
            input2: Two-element ``(img, lbp)`` sequence for the second branch.
        """
        # Pack both inputs before any forward pass, then embed in order.
        packed1 = self._pack(input1)
        packed2 = self._pack(input2)
        return self.forward_once(packed1), self.forward_once(packed2)