In [None]:
import itertools
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as transforms
from matplotlib import pyplot as plt
from torch.utils.data import DataLoader
from torchsummary import summary
from tqdm import tqdm

from dataloader import dataset

In [None]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device

In [None]:
def std_loss(z_a, z_b, gamma=1.0, eps=1e-04):
    """VICReg variance term: hinge penalty on per-dimension std falling below ``gamma``.

    For each view the standard deviation of every feature is taken along dim 0
    (the batch). Features whose std is below ``gamma`` are penalised linearly
    via ``relu(gamma - std)``. The mean penalty of view a is added to that of view b.

    Args:
        z_a: embeddings of the first view; dim 0 is treated as the batch.
        z_b: embeddings of the second view; same layout as ``z_a``.
        gamma: target standard deviation (default 1.0, the value previously hard-coded).
        eps: added to the variance before ``sqrt`` for numerical stability
            (default 1e-04, the value previously hard-coded).

    Returns:
        Scalar tensor ``mean(relu(gamma - std(z_a))) + mean(relu(gamma - std(z_b)))``.

    Note:
        ``Tensor.var`` is unbiased by default, so a batch of a single sample
        yields NaN. Callers should only pass batches with at least two rows.
    """
    def _hinge(z):
        # sqrt(var + eps) keeps the gradient finite when a feature collapses to zero variance.
        std = torch.sqrt(z.var(dim=0) + eps)
        return torch.mean(F.relu(gamma - std))

    return _hinge(z_a) + _hinge(z_b)


def off_diagonal(x):
    """Return the off-diagonal entries of a square matrix as a 1-D tensor.

    Entries come out in row-major order with the diagonal removed. For an n x n
    input the result has ``n * n - n`` elements.

    Raises:
        AssertionError: if ``x`` is not square.
    """
    rows, cols = x.shape
    assert rows == cols
    # Dropping the final element leaves n*n - 1 == (n - 1) * (n + 1) values.
    # Viewed as (n - 1) rows of width n + 1, each row starts on a diagonal
    # entry, so discarding column 0 removes exactly the diagonal.
    trimmed = x.flatten()[:-1]
    staggered = trimmed.view(rows - 1, rows + 1)
    return staggered[:, 1:].flatten()

def cov_loss(z_a, z_b):
    """VICReg covariance term: squared off-diagonal covariance, summed over both views.

    For each view the embeddings are centred along dim 0 (the batch) and the
    unbiased covariance matrix ``Z^T Z / (N - 1)`` is formed. The squared
    off-diagonal entries are then summed and divided by the feature count D.

    Args:
        z_a: 2-D embeddings of the first view; dim 0 is the batch (N) and dim 1
            the features (D). N and D are both read from ``z_a``.
        z_b: 2-D embeddings of the second view, same layout as ``z_a``.

    Returns:
        Scalar tensor ``penalty(z_a) + penalty(z_b)``.

    Note:
        The ``N - 1`` normaliser divides by zero for a single-sample batch.
    """
    num_samples, num_features = z_a.shape[0], z_a.shape[1]

    def _penalty(z):
        centred = z - z.mean(dim=0)
        cov = (centred.T @ centred) / (num_samples - 1)
        return off_diagonal(cov).pow_(2).sum() / num_features

    return _penalty(z_a) + _penalty(z_b)


In [None]:
class BasicBlock(nn.Module):
    """Two-convolution 3-D residual block (the ResNet "basic" block).

    The main path is conv3x3x3(stride) -> BN -> ReLU -> conv3x3x3 -> BN. The
    shortcut is the identity, or a strided 1x1x1 conv + BN projection whenever
    the stride or channel count changes. The two are summed and passed through ReLU.
    """

    # Output channels = out_channels * expansion (1 for the basic block).
    expansion = 1

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        shared = dict(kernel_size=3, padding=1, bias=False)
        self.conv1 = nn.Conv3d(in_channels, out_channels, stride=stride, **shared)
        self.bn1 = nn.BatchNorm3d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv3d(out_channels, out_channels, stride=1, **shared)
        self.bn2 = nn.BatchNorm3d(out_channels)

        # A projection is only needed when the shortcut's shape would not match the main path.
        shape_changes = stride != 1 or in_channels != out_channels
        self.downsample = (
            nn.Sequential(
                nn.Conv3d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm3d(out_channels),
            )
            if shape_changes
            else None
        )

    def forward(self, x):
        residual = self.relu(self.bn1(self.conv1(x)))
        residual = self.bn2(self.conv2(residual))
        shortcut = x if self.downsample is None else self.downsample(x)
        residual += shortcut
        return self.relu(residual)

class ResNet3D(nn.Module):
    """3-D ResNet backbone that maps a volume to a flat embedding vector.

    Architecture:
        - Stem: 7x7x7 stride-2 conv + BN + ReLU + 3x3x3 stride-2 max-pool.
        - Stages: four residual stages with 64/128/256/512 channels; stages 2-4
          start with stride 2.
        - Head: global average pooling followed by a linear projection.

    Args:
        block: residual block class. It must expose an ``expansion`` class
            attribute and be callable as ``block(in_channels, out_channels[, stride])``.
        layers: number of blocks in each of the four stages (indices 0-3 are used).
        in_channels: number of channels of the input volume.
        out_features: size of the final linear projection, i.e. the embedding
            dimension. Defaults to 4096, the value that was previously hard-coded.
    """

    def __init__(self, block, layers, in_channels=1, out_features=4096):
        super(ResNet3D, self).__init__()
        # Running channel count consumed/updated by _make_layer as stages are built.
        self.in_channels = 64
        self.conv1 = nn.Conv3d(in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm3d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool3d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.avgpool = nn.AdaptiveAvgPool3d((1, 1, 1))
        self.fc = nn.Linear(512 * block.expansion, out_features)

    def _make_layer(self, block, out_channels, blocks, stride=1):
        """Build one stage: the first block applies ``stride``; the rest keep resolution."""
        layers = [block(self.in_channels, out_channels, stride)]
        self.in_channels = out_channels * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.in_channels, out_channels))
        return nn.Sequential(*layers)

    def forward(self, x):
        """Return a ``(batch, out_features)`` embedding for a 5-D input ``(batch, C, D, H, W)``."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x

def ResNet18_3D(num_classes=1000, in_channels=3):
    """Build an 18-layer 3-D ResNet from ``BasicBlock`` (four stages of two blocks).

    Args:
        num_classes: accepted for API compatibility but NOT used. The output size
            is fixed by ``ResNet3D``'s final ``fc`` layer, which is 4096 features
            when built through this factory.
        in_channels: number of channels of the input volume.

    Returns:
        A ``ResNet3D`` instance.
    """
    blocks_per_stage = [2] * 4  # 4 stages x 2 blocks x 2 convs + stem conv + fc = 18 layers
    network = ResNet3D(BasicBlock, blocks_per_stage, in_channels=in_channels)
    return network



In [None]:

# Build the single-channel 3-D ResNet-18 encoder that is trained below.
model = ResNet18_3D(in_channels=1).to(device)

# Smoke-test the forward pass on a random batch of two 64^3 single-channel volumes.
# Eval mode + no_grad keep the random input from updating BatchNorm running
# statistics and avoid building an autograd graph. Train mode is restored afterwards.
model.eval()
with torch.no_grad():
    input_tensor = torch.randn((2, 1, 64, 64, 64), device=device)
    output = model(input_tensor)
model.train()
output.shape

In [None]:
# Training data from the project-local `dataloader.dataset`. The loops below unpack
# each batch as three items: inputs, targets, and a third value they do not use.
Trainingset = dataset(
    file_path1="./reg_data/00/",
    file_path2="./reg_data/12/",
    force=0,
)
trainingloader = DataLoader(
    dataset=Trainingset,
    batch_size=4,
    shuffle=True,
)

In [None]:
# Invariance term: mean-squared error between the embeddings of the two views.
criterion = nn.MSELoss()
# Adam with a small learning rate (1e-5) and a reduced first-moment decay (beta1 = 0.5).
optimizer = optim.Adam(
    model.parameters(),
    betas=(0.5, 0.999),
    lr=0.00001,
)

In [None]:
# VICReg-style training. Each batch yields two views (inputs, targets). The encoder
# is pushed to give them matching embeddings (MSE), to keep per-dimension std up
# (std_loss) and to decorrelate embedding dimensions (cov_loss).
NUM_EPOCHS = 400
SIM_COEFF, STD_COEFF, COV_COEFF = 25, 25, 1  # loss weights

losses = []
# Ensure BatchNorm uses batch statistics even if an earlier cell left the model in eval mode.
model.train()
for epoch in range(NUM_EPOCHS):
    loss_total = 0.0
    mse_loss = 0.0
    n_seen = 0
    for inputs, targets, _force in tqdm(trainingloader, desc=f"Epoch {epoch+1}/{NUM_EPOCHS}"):
        batch_size = inputs.size(0)
        if batch_size < 2:
            # A trailing single-sample batch makes var(dim=0) and the (N - 1)
            # covariance normaliser undefined: the loss would be NaN/inf and
            # optimizer.step() would corrupt the weights.
            continue

        real_A = inputs.to(device).unsqueeze(1).float()
        real_B = targets.to(device).unsqueeze(1).float()

        optimizer.zero_grad()

        repr_a = model(real_A)
        repr_b = model(real_B)

        _sim_loss = criterion(repr_a, repr_b)
        _std_loss = std_loss(repr_a, repr_b)
        _cov_loss = cov_loss(repr_a, repr_b)

        loss = SIM_COEFF * _sim_loss + STD_COEFF * _std_loss + COV_COEFF * _cov_loss

        loss.backward()
        optimizer.step()

        loss_total += loss.item() * batch_size
        mse_loss += _sim_loss.item() * batch_size
        n_seen += batch_size

    # Sample-weighted epoch means over the samples actually trained on.
    loss_total /= max(n_seen, 1)
    mse_loss /= max(n_seen, 1)
    losses.append(loss_total)

    print(f"Epoch [{epoch+1}/{NUM_EPOCHS}], Loss: {loss_total:.4f}, mse: {mse_loss:.4f}")

In [None]:
# Per-epoch mean combined training loss recorded by the training loop
# (epochs numbered from 1 to match the printed log).
fig, ax = plt.subplots()
ax.plot(range(1, len(losses) + 1), losses, label='Loss')
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.set_title('Training Loss over Epochs')
ax.legend()
plt.show()

In [None]:
# Persist the trained encoder weights. The directory is created first so a missing
# folder cannot throw away the result of the long training run above.
SAVE_PATH = './saved_model/ResNet/model400.pth'
os.makedirs(os.path.dirname(SAVE_PATH), exist_ok=True)
torch.save(model.state_dict(), SAVE_PATH)

In [None]:
# Sample-weighted mean MSE between the two views' embeddings over the training set.
# Runs in eval mode under no_grad: BatchNorm uses (and does not update) its running
# statistics, and no autograd graph is built. The model's previous mode is restored
# afterwards so re-running the training cell is unaffected.
was_training = model.training
model.eval()
mse = 0.0
try:
    with torch.no_grad():
        for inputs, targets, _ in trainingloader:
            real_A = inputs.to(device).unsqueeze(1).float()
            real_B = targets.to(device).unsqueeze(1).float()

            repr_a = model(real_A)
            repr_b = model(real_B)

            loss = criterion(repr_a, repr_b)
            mse += loss.item() * inputs.size(0)
finally:
    model.train(was_training)

mse /= len(trainingloader.dataset)

In [None]:
# Display the mean per-sample MSE between the two views' embeddings computed in the previous cell.
mse