In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import pennylane as qml
import timm


class ViTEncoder(nn.Module):
    """Image encoder: a timm Vision Transformer followed by an MLP projection head.

    The backbone is built with ``num_classes=0`` (per timm's convention this
    drops the classifier so the model returns features, not logits) and its
    feature width is read from ``self.vit.embed_dim``. The projection head maps
    those features to ``embedding_dim`` and the output is L2-normalized per row.

    Args:
        model_name: timm model identifier for the backbone.
        embedding_dim: Output width of the projection head.
        pretrained: Forwarded to ``timm.create_model``.
        freeze_backbone: If True, backbone parameters get ``requires_grad=False``
            and the backbone is always run in eval mode under ``torch.no_grad``.
        hidden_dim: Width of the projection head's hidden layer.
        dropout: Dropout probability inside the projection head.

    Attributes:
        n_trainable_params: Number of parameter elements with
            ``requires_grad=True`` at construction time.
    """

    def __init__(self, model_name='vit_large_patch16_224', embedding_dim=1024,
                 pretrained=True, freeze_backbone=True, hidden_dim=512, dropout=0.1):
        super().__init__()
        self.vit = timm.create_model(model_name, pretrained=pretrained, num_classes=0)
        self.vit_dim = self.vit.embed_dim
        self.freeze_backbone = freeze_backbone
        if freeze_backbone:
            # Same effect as setting requires_grad=False on every backbone parameter.
            self.vit.requires_grad_(False)
        self.projector = nn.Sequential(
            nn.Linear(self.vit_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, embedding_dim),
        )
        self.n_trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)

    def forward(self, x, return_features=False):
        """Encode a batch of images.

        Args:
            x: Input batch, passed unchanged to the timm backbone.
            return_features: If True, return the raw backbone features and
                skip the projection head.

        Returns:
            The backbone features if ``return_features`` is True, otherwise the
            projected embeddings L2-normalized along dim 1.
        """
        if self.freeze_backbone:
            # Re-assert eval mode on every call: a parent module's .train() would
            # otherwise put the frozen backbone's train-mode-dependent layers
            # into training behaviour.
            self.vit.eval()
            with torch.no_grad():
                features = self.vit(x)
        else:
            features = self.vit(x)
        if return_features:
            return features
        embedding = self.projector(features)
        return F.normalize(embedding, dim=1)


class QuantumEnhancer(nn.Module):
    """Map classical embeddings to basis-state amplitude magnitudes of a variational circuit.

    Each input row is linearly projected to ``2 ** n_qubits`` values,
    amplitude-encoded (with normalization) on ``n_qubits`` wires, processed by
    ``n_qlayers`` trainable ``BasicEntanglerLayers``, and read out as the
    magnitude ``|amplitude|`` of every computational basis state.

    Args:
        input_dim: Width of the incoming embeddings.
        n_qubits: Number of simulated qubits; the output width is ``2 ** n_qubits``.
        n_qlayers: Number of entangling layers in the circuit.

    Attributes:
        n_trainable_params: Number of parameter elements with
            ``requires_grad=True`` at construction time.

    NOTE(review): a projected row that is exactly all-zero cannot be
    normalized by ``AmplitudeEmbedding``; confirm upstream inputs rule this out.
    """

    # Floor applied to probabilities before the square root so gradients stay
    # finite where a basis-state probability is (numerically) zero.
    _PROB_EPS = 1e-12

    def __init__(self, input_dim=1024, n_qubits=10, n_qlayers=3):
        super().__init__()
        self.n_qubits = n_qubits
        self.n_qlayers = n_qlayers
        self.quantum_dim = 2 ** n_qubits
        self.projection = nn.Linear(input_dim, self.quantum_dim, bias=True)
        nn.init.xavier_uniform_(self.projection.weight)
        dev = qml.device("default.qubit", wires=n_qubits)

        @qml.qnode(dev, interface="torch", diff_method="backprop")
        def quantum_circuit(inputs, weights):
            qml.AmplitudeEmbedding(inputs, wires=range(n_qubits), normalize=True)
            qml.templates.BasicEntanglerLayers(weights, wires=range(n_qubits))
            # Measure real-valued probabilities rather than the complex state:
            # TorchLayer casts the QNode output to the (real) input dtype, which
            # would silently drop the imaginary part of qml.state().
            return qml.probs(wires=range(n_qubits))

        weight_shape = qml.templates.BasicEntanglerLayers.shape(
            n_layers=n_qlayers, n_wires=n_qubits
        )
        self.quantum_layer = qml.qnn.TorchLayer(quantum_circuit, {"weights": weight_shape})
        self.n_trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)

    def forward(self, embeddings):
        """Return ``|amplitude|`` for every basis state, one row per input row.

        Args:
            embeddings: Tensor whose last dimension is ``input_dim``.

        Returns:
            Real tensor whose last dimension is ``2 ** n_qubits``. Because the
            probabilities sum to one, each row has (up to the epsilon floor)
            unit L2 norm.
        """
        projected = self.projection(embeddings)
        probabilities = self.quantum_layer(projected)
        # |amplitude| == sqrt(probability); the floor avoids an infinite sqrt'(0).
        return torch.sqrt(probabilities.clamp_min(self._PROB_EPS))


class QuPIDModel(nn.Module):
    """End-to-end model: image encoder followed by the quantum enhancer.

    Args:
        vit_encoder: Module mapping inputs to embeddings.
        quantum_enhancer: Module mapping those embeddings to the final output.
    """

    def __init__(self, vit_encoder, quantum_enhancer):
        super().__init__()
        # Assigning nn.Modules as attributes registers them as child modules.
        self.vit_encoder = vit_encoder
        self.quantum_enhancer = quantum_enhancer

    def forward(self, x):
        """Feed ``x`` through the encoder, then the enhancer, and return the result."""
        return self.quantum_enhancer(self.vit_encoder(x))

    def get_vit_embedding(self, x):
        """Return only the encoder's output for ``x``, bypassing the enhancer."""
        encoder = self.vit_encoder
        return encoder(x)


def nt_xent_loss(z1, z2, temperature=0.07):
    """NT-Xent (normalized temperature-scaled cross-entropy) loss for paired views.

    Row ``z1[i]`` and row ``z2[i]`` form a positive pair; every other row of
    the concatenated ``2 * batch`` set acts as a negative for that anchor.
    Similarities are raw dot products, so pass L2-normalized embeddings if
    cosine similarity is intended.

    Args:
        z1: Embeddings of the first view, shape ``(batch, dim)``.
        z2: Embeddings of the second view, same shape as ``z1``.
        temperature: Divisor applied to the similarities before the softmax.

    Returns:
        Scalar tensor: the mean loss over all ``2 * batch`` anchors.
    """
    batch_size = z1.shape[0]
    z = torch.cat([z1, z2], dim=0)
    sim_matrix = torch.matmul(z, z.T) / temperature
    # Exclude self-similarity. -inf is representable in every float dtype,
    # unlike a large negative constant such as -1e9, which overflows fp16.
    self_mask = torch.eye(2 * batch_size, device=z.device, dtype=torch.bool)
    sim_matrix = sim_matrix.masked_fill(self_mask, float("-inf"))
    # Anchor i's positive sits at column i + B (first view) or i - B (second view).
    targets = torch.arange(2 * batch_size, device=z.device).roll(batch_size)
    # cross_entropy evaluates -log softmax via log-sum-exp, so large logits
    # cannot overflow the way exp(sim) / sum(exp(sim)) does.
    return F.cross_entropy(sim_matrix, targets)


def quantum_contrastive_loss(q_state1, q_state2, temperature=0.07):
    """Contrastive loss using state fidelity as the similarity.

    Each row is normalized to a unit vector, and the fidelity
    ``|<q_state2[j] | q_state1[i]>|^2`` is computed for every pair. Row ``i``
    of ``q_state1`` is classified against all rows of ``q_state2``, with row
    ``i`` as the correct match (one direction only, q_state1 -> q_state2).

    Args:
        q_state1: Real or complex tensor of shape ``(batch, dim)``.
        q_state2: Tensor of the same shape and dtype kind as ``q_state1``.
        temperature: Divisor applied to the fidelities to form logits.

    Returns:
        Scalar cross-entropy loss tensor.
    """
    batch_size = q_state1.shape[0]
    q_state1 = F.normalize(q_state1, dim=-1)
    q_state2 = F.normalize(q_state2, dim=-1)
    # conj() is a no-op on real tensors, so real inputs get exactly the squared
    # dot product, while complex states get the true fidelity |<b|a>|^2.
    overlap = torch.matmul(q_state1, q_state2.conj().T)
    fidelity_matrix = overlap.abs() ** 2
    logits = fidelity_matrix / temperature
    labels = torch.arange(batch_size, device=q_state1.device)
    return F.cross_entropy(logits, labels)


def combined_contrastive_loss(z1, z2, q1, q2, temperature=0.07, quantum_weight=0.5):
    """Blend the classical NT-Xent loss with the quantum fidelity loss.

    The blended value is ``(1 - quantum_weight) * classical + quantum_weight * quantum``.
    Both component losses use the same ``temperature``.

    Args:
        z1, z2: Paired classical embeddings passed to ``nt_xent_loss``.
        q1, q2: Paired quantum outputs passed to ``quantum_contrastive_loss``.
        temperature: Temperature shared by both component losses.
        quantum_weight: Weight of the quantum term; the classical term gets
            ``1 - quantum_weight``.

    Returns:
        Tuple ``(total, classical, quantum)`` of loss tensors.
    """
    loss_classical = nt_xent_loss(z1, z2, temperature)
    loss_quantum = quantum_contrastive_loss(q1, q2, temperature)
    classical_weight = 1 - quantum_weight
    blended = classical_weight * loss_classical + quantum_weight * loss_quantum
    return blended, loss_classical, loss_quantum