1. 检查可用GPU

In [1]:
import torch
import os

# Check CUDA availability, select the GPU used by this experiment, and list
# every visible device together with its multiprocessor count.
GPU_INDEX = 5  # physical GPU used for this run; adjust to the host machine

# Defined before the branch so this cell also completes on CPU-only kernels
# (the per-device properties loop used to run outside the CUDA check and
# raised NameError when CUDA was unavailable).
device_count = 0
if torch.cuda.is_available():
    # Select the GPU first, then read back the device index actually in use.
    torch.cuda.set_device(GPU_INDEX)
    device = torch.cuda.current_device()
    print("Current GPU device: ", device)
    device_count = torch.cuda.device_count()
    print(f"Number of available GPUs: {device_count}")

    # Name of each visible GPU.
    for i in range(device_count):
        device_name = torch.cuda.get_device_name(i)
        print(f"Device {i}: {device_name}")

    # Streaming-multiprocessor count of each GPU (printed as "cores").
    for i in range(device_count):
        properties = torch.cuda.get_device_properties(i)
        print(f"GPU {i}: {properties.multi_processor_count} cores")
else:
    print("No CUDA devices available")

print(f"cwd:{os.getcwd()}")


Current GPU device:  5
Number of available GPUs: 8
Device 0: NVIDIA GeForce RTX 3090
Device 1: NVIDIA GeForce RTX 3090
Device 2: NVIDIA GeForce RTX 3090
Device 3: NVIDIA GeForce RTX 3090
Device 4: NVIDIA GeForce RTX 3090
Device 5: NVIDIA GeForce RTX 3090
Device 6: NVIDIA GeForce RTX 3090
Device 7: NVIDIA GeForce RTX 3090
GPU 0: 82 cores
GPU 1: 82 cores
GPU 2: 82 cores
GPU 3: 82 cores
GPU 4: 82 cores
GPU 5: 82 cores
GPU 6: 82 cores
GPU 7: 82 cores
cwd:/home/lihaiyue/data/snapshotscope/replicate/exp/fouriernet_mse_3000


2. import 需用模块

In [2]:
import datetime
import time
import logging
import os
import subprocess
import math
import numpy as np
import pandas as pd

import torch
import torch.cuda.comm
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from torch.cuda._utils import _get_device_index
from collections import OrderedDict

import sys
sys.path.append('/home/lihaiyue/data/snapshotscope/replicate')
from utils.control import *
from utils.output_control import *
from utils.networks import *

dataset创建函数

In [3]:
from torch.utils.data import Dataset

class DiffuserMirflickrDataset(Dataset):
    """
    Paired DiffuserCam (diffused) / DSLR (ground-truth) images from DLMD, the
    DiffuserCam Lensless Mirflickr Dataset, stored as ``.npy`` arrays.

    The first CSV column holds a filename shared by both image sets. For a
    name ``n`` the arrays ``<data_dir>/n.npy`` and ``<label_dir>/n.npy`` are
    loaded. Their last axis is moved to the front (presumably HWC -> CHW;
    confirm against the preprocessed data), and they are wrapped as torch
    tensors without copying.

    Args:
        csv_path: Path to the .csv file listing the image filenames (diffused
            and ground-truth images share the same filename).
        data_dir: Directory containing the diffused image arrays.
        label_dir: Directory containing the ground-truth image arrays.
        transform (optional): Callable applied to every sample dict. It is
            skipped when falsy; the default None leaves samples untouched.
    """

    def __init__(self, csv_path, data_dir, label_dir, transform=None):
        super().__init__()
        self.csv_contents = pd.read_csv(csv_path)
        self.data_dir, self.label_dir = data_dir, label_dir
        self.transform = transform

    def __len__(self):
        # One sample per CSV row.
        return self.csv_contents.shape[0]

    def __getitem__(self, idx):
        """Return ``{"image": tensor, "label": tensor}`` for CSV row ``idx``,
        passed through ``self.transform`` when one is set."""
        file_stem = self.csv_contents.iloc[idx, 0]

        # Both arrays are read before either is converted (diffused first).
        # NOTE: an earlier variant stripped a 9-character suffix from the CSV
        # name before appending ".npy"; the current data uses bare stems.
        raw_arrays = [
            np.load(os.path.join(folder, file_stem) + ".npy")
            for folder in (self.data_dir, self.label_dir)
        ]
        diffused, ground_truth = (
            torch.from_numpy(arr.transpose((2, 0, 1))) for arr in raw_arrays
        )

        sample = {"image": diffused, "label": ground_truth}
        return self.transform(sample) if self.transform else sample


exp代码

进行基础设定

In [4]:
import logging

logging.basicConfig(filename="out.log", level=logging.DEBUG, format="%(message)s")

# define optimization hyperparameters
learning_rate = 1e-4
lpips_weight = 0
lpips_step_size = 0.1
lpips_step_milestones = []
num_iterations = 3000

# setup the simulation parameters, make a microscope
image_shape = (1080, 1920)
num_chunks = 1
devices = [torch.cuda.current_device()]
print(f"devices={devices}")

# calculate downsampled sizes
downsample = 4
downsampled_image_shape = [int(s / downsample) for s in image_shape]


devices=[5]


定义后续各步骤要用到的构建函数（重建网络、优化器、数据集、数据加载器）

In [5]:
def create_reconstruction_network():
    """Build the FourierNetRGB reconstruction network on the first configured GPU.

    Reads the module-level ``downsampled_image_shape`` and ``devices``.
    """
    return FourierNetRGB(
        20,  # Fourier-conv feature maps (printed model: FourierConv2D(3, 20, ...))
        downsampled_image_shape,
        fourier_conv_args={"stride": 2},
        conv_kernel_sizes=[(11, 11)] * 3,  # three 11x11 conv layers
        conv_fmap_nums=[64, 64, 3],  # last layer outputs the 3 RGB channels
        input_scaling_mode=None,
        device=devices[0],
    )

def initialize_reconstruction(latest=None):
    """Create the reconstruction network, optionally restoring checkpoint weights.

    Args:
        latest: checkpoint dict or None. When given, its
            ``"deconv_state_dict"`` is loaded with ``strict=True``.

    Returns:
        The (possibly restored) network.
    """
    network = create_reconstruction_network()
    if latest is None:
        return network

    print("[info] loading from checkpoint")
    network.load_state_dict(latest["deconv_state_dict"], strict=True)
    return network

def initialize_optimizer(deconv, latest=None):
    """Create an Adam optimizer over all of ``deconv``'s parameters.

    All parameters share one group whose learning rate is the module-level
    ``learning_rate``.

    Args:
        deconv: network whose parameters are optimized.
        latest: checkpoint dict or None. When given, the optimizer state is
            restored from ``latest["opt_state_dict"]``.

    Returns:
        The Adam optimizer.
    """
    # A single default parameter group; its lr is the optimizer default, which
    # is equivalent to passing an explicit group with the same "lr".
    adam = optim.Adam(deconv.parameters(), lr=learning_rate)
    if latest is None:
        return adam

    adam.load_state_dict(latest["opt_state_dict"])
    return adam

def create_dataset(test=False, base_path="/home/lihaiyue/data/snapshotscope/data/dlmd/dataset"):
    """Build the DLMD training or test split as a DiffuserMirflickrDataset.

    Args:
        test: if True, read ``dataset_test.csv``; otherwise ``dataset_train.csv``.
        base_path: root directory of the DLMD dataset. It must contain
            ``diffuser_images/``, ``ground_truth_lensed/`` and the split CSVs.
            The default keeps the original hard-coded location.

    Returns:
        A DiffuserMirflickrDataset (defined earlier in this notebook) whose items
        are dicts ``{"image": tensor, "label": tensor}``.
    """
    data_dir = os.path.join(base_path, "diffuser_images")
    label_dir = os.path.join(base_path, "ground_truth_lensed")
    csv_name = "dataset_test.csv" if test else "dataset_train.csv"
    csv_path = os.path.join(base_path, csv_name)
    return DiffuserMirflickrDataset(csv_path, data_dir, label_dir)

def create_dataloader(dataset, test=False):
    """Wrap ``dataset`` in a DataLoader configured for training or evaluation.

    Training: batches of 30, shuffled, 10 worker processes.
    Evaluation (``test=True``): batches of 1 in fixed order, 4 worker
    processes. The test/validation code requires batch size 1.
    """
    if test:
        loader_options = dict(num_workers=4, batch_size=1, shuffle=False)
    else:
        loader_options = dict(num_workers=10, batch_size=30, shuffle=True)
    return torch.utils.data.DataLoader(dataset, **loader_options)

开始训练！

initialize model and optimizer

In [None]:
# Resume from latest.pt when it exists; otherwise start from scratch.
latest = torch.load("latest.pt") if os.path.exists("latest.pt") else None

deconv = initialize_reconstruction(latest=latest)
print(deconv)

# The optimizer state is restored from the same checkpoint when available.
opt = initialize_optimizer(deconv, latest=latest)

initialize data

In [None]:
# Hold out a fixed-size validation subset from the training split.
# The split sizes are derived from the dataset length, so random_split no
# longer fails when the CSV row count differs from 912. For a 912-row CSV
# this reproduces the previous [900, 12] split exactly (same seed).
num_val_samples = 12
full_train_dataset = create_dataset()
num_train_samples = len(full_train_dataset) - num_val_samples
if num_train_samples <= 0:
    raise ValueError(
        f"training CSV has {len(full_train_dataset)} rows; need more than "
        f"{num_val_samples} to hold out a validation subset"
    )
# `dataset` is the training subset and `val_dataset` the held-out subset.
dataset, val_dataset = torch.utils.data.random_split(
    full_train_dataset,
    [num_train_samples, num_val_samples],
    generator=torch.Generator().manual_seed(42),
)
dataloader = create_dataloader(dataset)
val_dataloader = create_dataloader(val_dataset, test=True)

initialize iteration count

In [None]:
# Restore training progress (iteration counter and MSE histories) from the
# checkpoint, or start fresh when there is none.
if latest is None:
    latest_iter = 0
    mses = []
    validate_mses = []
else:
    latest_iter = latest["it"]
    mses = latest["mses"]
    validate_mses = latest["validate_mses"]
    # LPIPS loss history is intentionally not restored
    # (formerly: lpips_losses = latest["lpips_losses"]).

it = int(latest_iter)

remove loaded checkpoint

In [None]:
# Release the checkpoint dict now that its contents were loaded into the
# network and optimizer, then ask PyTorch to return cached, unused GPU memory.
# NOTE: `del` (not `latest = None`) is deliberate. Re-running the
# iteration-count cell afterwards fails loudly instead of silently resetting
# `it` and the loss histories.
checkpoint_was_loaded = latest is not None
if checkpoint_was_loaded:
    del latest
    torch.cuda.empty_cache()

create folder for validation data

In [None]:
# Directory for validation snapshots. It is relative to the working directory,
# i.e. it is created next to exp.ipynb when the notebook runs from there.
val_dir = "snapshots/validate/"
# exist_ok avoids the check-then-create race and keeps this cell idempotent.
os.makedirs(val_dir, exist_ok=True)

run training

In [None]:
# run psf training
train_rgb_recon(
    deconv,
    opt,
    dataloader,
    devices,
    mses,
    num_iterations,
    lpips_weight=lpips_weight,
    lpips_step_milestones=lpips_step_milestones,
    lpips_step_size=lpips_step_size,
    checkpoint_interval = 20,
    snapshot_interval = 60,
    validate_mses=validate_mses,
    validate_args={
        "dataloader": val_dataloader,
        "devices": devices,
        "save_dir": val_dir,
    },
    it=it,
)

run testing

In [6]:
def test(checkpoint_path="snapshots/state2999.pt"):
    """Evaluate a trained reconstruction network on the DLMD test split.

    Args:
        checkpoint_path: checkpoint to evaluate; it must contain
            ``"deconv_state_dict"``. Defaults to the iteration-2999 snapshot
            (TODO: change per run).

    Returns:
        Whatever ``test_rgb_recon`` returns (presumably the per-image losses).
    """
    # Check the same file that is loaded. Previously this tested for latest.pt
    # but loaded the snapshot, so a missing snapshot crashed and a missing
    # latest.pt silently evaluated an untrained network.
    if os.path.exists(checkpoint_path):
        latest = torch.load(checkpoint_path)
    else:
        print(f"[warning] checkpoint {checkpoint_path} not found; "
              "evaluating an untrained network")
        latest = None
    deconv = initialize_reconstruction(latest=latest)
    deconv.eval()
    print(deconv)
    # numel() also works for non-contiguous tensors, unlike view(-1).
    num_params = sum(p.numel() for p in deconv.parameters())
    print(num_params)

    # Test split, one image per batch in fixed order.
    dataset = create_dataset(test=True)
    dataloader = create_dataloader(dataset, test=True)

    # Drop the checkpoint dict now that its weights are in `deconv`.
    if latest is not None:
        del latest
        torch.cuda.empty_cache()

    # Results go to ./test (created if missing).
    save_dir = "./test"
    os.makedirs(save_dir, exist_ok=True)

    losses = test_rgb_recon(deconv, dataloader, devices, save_dir=save_dir)

    return losses

test()

[info] loading from checkpoint
Sequential(
  (fourier_conv): FourierConv2D(3, 20, kernel_size=[270, 480])
  (fourier_relu): LeakyReLU(negative_slope=0.01)
  (fourier_bn): BatchNorm2d(20, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv2d_1): Conv2d(20, 64, kernel_size=(11, 11), stride=(1, 1), padding=[5, 5])
  (conv1_relu): LeakyReLU(negative_slope=0.01)
  (conv1_bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv2d_2): Conv2d(64, 64, kernel_size=(11, 11), stride=(1, 1), padding=[5, 5])
  (conv2_relu): LeakyReLU(negative_slope=0.01)
  (conv2_bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv2d_3): Conv2d(64, 3, kernel_size=(11, 11), stride=(1, 1), padding=[5, 5])
  (conv3_relu): ReLU()
)
16226175
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]


  fourier_im = torch.fft(F.pad(im.unsqueeze(-1), pad_size), 2)
  real_feats = torch.ifft(fourier_feats, 2).index_select(-1, indices)


Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
mse=0.012008828052785248
 lpips=-0.004603580688126385

