In [1]:
import scipy.io as sio
import numpy as np
import os
import tensorboard
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import matplotlib.pyplot as plt

from torch.utils.tensorboard import SummaryWriter
from utils.model import DownSampleConv, UpSampleConv

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')


In [2]:
D = sio.loadmat('/home/hou63/pj2/Nematic_RL/datas/D.mat')['D']
# print keys
print(D.shape)
d11 = D[:,:,0,40:]
d12 = D[:,:,1,40:]
print(d11.shape)
# put axis 2 to 0
d11 = np.moveaxis(d11, 2, 0)
d12 = np.moveaxis(d12, 2, 0)
print(d11.shape)
ds = np.stack((d11, d12), axis=1)
print(ds.shape)
ds = torch.tensor(ds, dtype=torch.float32, device=device)

(256, 256, 3, 257)
(256, 256, 217)
(217, 256, 256)
(217, 2, 256, 256)


In [None]:
# plot D
plt.figure()
plt.imshow(d11[:,:,0])
plt.colorbar()
plt.show()

In [3]:

encoder = DownSampleConv()
decoder = UpSampleConv()

encoder = encoder.to(device)
decoder = decoder.to(device)

criterion = nn.MSELoss()
optimizer = optim.Adam(encoder.parameters(), lr=0.001)

# 训练参数
num_epochs = 10
dataset = TensorDataset(ds)  # 将 ds 封装为 dataset
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

In [4]:
num_epochs = 200
save_path = '/home/hou63/pj2/Nematic_RL/log_model/encoder_checkpoint.pth'  # 模型保存路径
best_loss = float('inf')  # 初始最优损失值设为无穷大

# 开始训练
for epoch in range(num_epochs):
    encoder.train()  # 设置 encoder 为训练模式
    decoder.train()  # 设置 decoder 为训练模式
    
    total_loss = 0.0

    for data in dataloader:
        # 将数据移到设备上
        x = data[0].to(device)
        
        # 前向传播: encoder 和 decoder
        encoded = encoder(x)
        reconstructed = decoder(encoded)
        
        # 计算损失
        loss = criterion(reconstructed, x)  # 自监督损失：重建误差
        
        # 反向传播与优化
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
    
    avg_loss = total_loss / len(dataloader)
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}')
    
    # 保存模型：当当前损失小于最佳损失时，保存模型
    if avg_loss < best_loss:
        best_loss = avg_loss
        torch.save(encoder.state_dict(), save_path)
        print(f'Model saved with loss {best_loss:.4f}')

Epoch [1/200], Loss: 0.2509
Model saved with loss 0.2509
Epoch [2/200], Loss: 0.2440
Model saved with loss 0.2440
Epoch [3/200], Loss: 0.2394
Model saved with loss 0.2394
Epoch [4/200], Loss: 0.2351
Model saved with loss 0.2351
Epoch [5/200], Loss: 0.2319
Model saved with loss 0.2319
Epoch [6/200], Loss: 0.2301
Model saved with loss 0.2301
Epoch [7/200], Loss: 0.2295
Model saved with loss 0.2295
Epoch [8/200], Loss: 0.2287
Model saved with loss 0.2287
Epoch [9/200], Loss: 0.2282
Model saved with loss 0.2282
Epoch [10/200], Loss: 0.2276
Model saved with loss 0.2276
Epoch [11/200], Loss: 0.2272
Model saved with loss 0.2272
Epoch [12/200], Loss: 0.2267
Model saved with loss 0.2267
Epoch [13/200], Loss: 0.2265
Model saved with loss 0.2265
Epoch [14/200], Loss: 0.2260
Model saved with loss 0.2260
Epoch [15/200], Loss: 0.2257
Model saved with loss 0.2257
Epoch [16/200], Loss: 0.2251
Model saved with loss 0.2251
Epoch [17/200], Loss: 0.2250
Model saved with loss 0.2250
Epoch [18/200], Loss: 0