In [1]:
import scipy.io as sio
import numpy as np
import os
import tensorboard
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import matplotlib.pyplot as plt

from torch.utils.tensorboard import SummaryWriter
from utils.model import DownSampleConv, UpSampleConv

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')


In [2]:
D = sio.loadmat('/home/hou63/pj2/Nematic_RL/datas/D.mat')['D']
# print keys
print(D.shape)
d11 = D[:,:,0,40:]
d12 = D[:,:,1,40:]
print(d11.shape)
# put axis 2 to 0
d11 = np.moveaxis(d11, 2, 0)
d12 = np.moveaxis(d12, 2, 0)
print(d11.shape)
ds = np.stack((d11, d12), axis=1)
print(ds.shape)
ds = torch.tensor(ds, dtype=torch.float32, device=device)

(256, 256, 3, 257)
(256, 256, 217)
(217, 256, 256)
(217, 2, 256, 256)


In [None]:
# plot D
plt.figure()
plt.imshow(d11[:,:,0])
plt.colorbar()
plt.show()

In [None]:

encoder = DownSampleConv()
decoder = UpSampleConv()

encoder = encoder.to(device)
decoder = decoder.to(device)

criterion = nn.MSELoss()
optimizer = optim.Adam(encoder.parameters(), lr=0.001)

# 训练参数
num_epochs = 10
dataset = TensorDataset(ds)  # 将 ds 封装为 dataset
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

In [5]:
for epoch in range(num_epochs):
    encoder.train()  # 设置 encoder 为训练模式
    decoder.train()  # 设置 decoder 为训练模式
    
    total_loss = 0.0

    for data in dataloader:
        # 将数据移到设备上
        x = data[0].to(device)
        
        # 前向传播: encoder 和 decoder
        encoded = encoder(x)
        reconstructed = decoder(encoded)
        
        # 计算损失
        loss = criterion(reconstructed, x)  # 自监督损失：重建误差
        
        # 反向传播与优化
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
    
    avg_loss = total_loss / len(dataloader)
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}')

Epoch [1/10], Loss: 0.2340
Epoch [2/10], Loss: 0.2274
Epoch [3/10], Loss: 0.2239
Epoch [4/10], Loss: 0.2216
Epoch [5/10], Loss: 0.2196
Epoch [6/10], Loss: 0.2183
Epoch [7/10], Loss: 0.2176
Epoch [8/10], Loss: 0.2168
Epoch [9/10], Loss: 0.2164
Epoch [10/10], Loss: 0.2164
