# In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader
from torchvision import transforms
from torch.utils.tensorboard import SummaryWriter
import torch.cuda.amp as amp
import os
from Mydataset import RDataset
from modle import ResNet1, ResNet2, ResNet3, ResNet4, GAPFC, MWBlock1, MWBlock2, MWBlock3
from centerloss import CenterLoss

# Define the model
class mmfCNN(nn.Module):
    """Backbone built from ResNet1..ResNet4 followed by a GAPFC classifier head.

    The input is passed through ``Rfeature1``, ``Rfeature2``, ``Rfeature3`` and
    ``TOTALfeature1`` in that order; the result is then fed to ``classifier``.
    ``forward`` returns both the classifier output and the backbone features
    so a caller can attach a feature-level loss next to the classification
    loss.
    """

    def __init__(self):
        super().__init__()
        # Attribute names and registration order are kept stable because they
        # determine the keys and ordering of the saved state_dict.
        self.Rfeature1 = ResNet1()
        self.Rfeature2 = ResNet2()
        self.Rfeature3 = ResNet3()
        self.TOTALfeature1 = ResNet4()
        # presumably the argument is the number of output units — TODO confirm
        # against GAPFC in modle.py
        self.classifier = GAPFC(1)

    def extract_features(self, x1):
        """Run ``x1`` through the four backbone modules and return the result."""
        backbone = (
            self.Rfeature1,
            self.Rfeature2,
            self.Rfeature3,
            self.TOTALfeature1,
        )
        for stage in backbone:
            x1 = stage(x1)
        return x1

    def forward(self, x1):
        """Return ``(classifier_output, features)`` for the input ``x1``."""
        features = self.extract_features(x1)
        return self.classifier(features), features

# Device setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Let cuDNN auto-tune its convolution algorithms.
torch.backends.cudnn.benchmark = True

# Initialize model, loss, and optimizer
RDI = mmfCNN().to(device)
bce_loss = nn.BCEWithLogitsLoss()
center_loss = CenterLoss(num_classes=2, feat_dim=512, use_gpu=device.type == 'cuda')
if device.type == 'cuda':
    center_loss.centers.data = center_loss.centers.data.to(device)

# The class centers of the center loss are trainable too: if only the model
# parameters are handed to the optimizer, the centers are never updated and
# the features are pulled towards their initial values for the whole run.
# They get their own parameter group without weight decay so that the
# regulariser does not drag the centers towards the origin.
# NOTE(review): assumes CenterLoss is an nn.Module that exposes `centers` as a
# parameter — confirm against centerloss.py.
optimizer = optim.SGD(
    [
        {'params': RDI.parameters()},
        {'params': center_loss.parameters(), 'weight_decay': 0.0},
    ],
    lr=0.00025,
    momentum=0.9,
    weight_decay=5e-4,
)
# Halve the learning rate every 100 scheduler steps.
scheduler = StepLR(optimizer, step_size=100, gamma=0.5)

# Data preprocessing and loading
# Resize, center-crop to 224x224, convert to a tensor and normalise each
# channel with the given mean/std (the values commonly used for ImageNet).
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

R_root = './data/RGB'

dataset_train = RDataset(R_root, transform, step='train')

# os.cpu_count() can return None when the core count cannot be determined;
# treat that as a single core instead of failing on `None // 2`.
num_workers = min(4, (os.cpu_count() or 1) // 2)

loader_kwargs = dict(
    batch_size=256,
    shuffle=True,
    drop_last=True,
    pin_memory=True,
    num_workers=num_workers,
)
# DataLoader rejects an explicit prefetch_factor when loading happens in the
# main process (num_workers == 0), so only pass it when workers are used.
if num_workers > 0:
    loader_kwargs['prefetch_factor'] = 3

dataload_train = DataLoader(dataset_train, **loader_kwargs)

# TensorBoard
writer = SummaryWriter(log_dir='tf-logs')

# Mixed precision training. Autocast and gradient scaling are only enabled on
# CUDA; on CPU both become explicit no-ops instead of emitting warnings.
use_amp = device.type == 'cuda'
scaler = amp.GradScaler(enabled=use_amp)

# Training loop
print("----- Training started -----")
Epoch = 1
# Global step counter used as the TensorBoard x-axis. It must live outside the
# epoch loop: resetting it every epoch would make later epochs overwrite the
# scalars logged by earlier ones.
total_train_step = 0
try:
    for epoch in range(Epoch):
        RDI.train()
        for batch_idx, (data1, label) in enumerate(dataload_train):
            data1, label = data1.to(device), label.to(device)

            with amp.autocast(enabled=use_amp):
                outputs, features = RDI(data1)
                outputs = torch.squeeze(outputs)

                # Compute losses: BCE on the logits (cast to float32) plus a
                # lightly weighted center loss on the backbone features.
                bce_loss_value = bce_loss(outputs.float(), label.float())
                center_loss_value = center_loss(features, label)
                loss = bce_loss_value + 0.001 * center_loss_value

            optimizer.zero_grad()
            # Scale the loss before backward, then let the scaler drive the
            # optimizer step and adjust its scale factor.
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            total_train_step += 1
            # NOTE(review): the scheduler is stepped once per batch, so
            # StepLR's step_size counts iterations, not epochs — confirm this
            # is the intended schedule.
            scheduler.step()

            with torch.no_grad():
                # Threshold the sigmoid probabilities at 0.5 for batch accuracy.
                mask = torch.sigmoid(outputs).ge(0.5).float()
                correct = (mask == label).sum().item()
                acc = correct / label.shape[0]
                writer.add_scalar('Training loss', loss.item(), total_train_step)
                writer.add_scalar('Training accuracy', acc, total_train_step)
finally:
    # Close TensorBoard writer even if training is interrupted, so that the
    # scalars logged so far are flushed to disk.
    writer.close()

# Save the model (only reached when training finished without an error)
model_path = './s_trained_model.pth'
torch.save(RDI.state_dict(), model_path)
print(f"Model saved to {model_path}")

torch.cuda.empty_cache()

# Output:
# ----- Training started -----
# Model saved to ./s_trained_model.pth
