In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [3]:
!unzip /content/drive/MyDrive/MayoDataset.zip

Archive:  /content/drive/MyDrive/MayoDataset.zip
  inflating: Mayo_s Dataset/test/C081/307.png  
  inflating: Mayo_s Dataset/test/C081/303.png  
  inflating: Mayo_s Dataset/test/C081/321.png  
  inflating: Mayo_s Dataset/test/C081/315.png  
  inflating: Mayo_s Dataset/test/C081/300.png  
  inflating: Mayo_s Dataset/test/C081/290.png  
  inflating: Mayo_s Dataset/test/C081/314.png  
  inflating: Mayo_s Dataset/test/C081/288.png  
  inflating: Mayo_s Dataset/test/C081/301.png  
  inflating: Mayo_s Dataset/test/C081/323.png  
  inflating: Mayo_s Dataset/test/C081/322.png  
  inflating: Mayo_s Dataset/test/C081/312.png  
  inflating: Mayo_s Dataset/test/C081/298.png  
  inflating: Mayo_s Dataset/test/C081/316.png  
  inflating: Mayo_s Dataset/test/C081/310.png  
  inflating: Mayo_s Dataset/test/C081/309.png  
  inflating: Mayo_s Dataset/test/C081/313.png  
  inflating: Mayo_s Dataset/test/C081/308.png  
  inflating: Mayo_s Dataset/test/C081/318.png  
  inflating: Mayo_s Dataset/test/C081/3

In [4]:
!mv Mayo_s\ Dataset MajoDataset

In [5]:
!pip install astra-toolbox torch torchvision numpy


Collecting astra-toolbox
  Downloading astra_toolbox-2.3.1-cp311-cp311-manylinux2014_x86_64.whl.metadata (2.4 kB)
Collecting nvidia-cuda-runtime-cu12==12.1.105 (from astra-toolbox)
  Downloading nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.0.2.54 (from astra-toolbox)
  Downloading nvidia_cufft_cu12-11.0.2.54-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
INFO: pip is looking at multiple versions of torch to determine which version is compatible with other requirements. This could take a while.
Collecting torch
  Downloading torch-2.7.0-cp311-cp311-manylinux_2_28_x86_64.whl.metadata (29 kB)
Collecting sympy>=1.13.3 (from torch)
  Downloading sympy-1.13.3-py3-none-any.whl.metadata (12 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.6.77 (from torch)
  Downloading nvidia

In [6]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import astra
import numpy as np


In [7]:
class AstraForward:
    def __init__(self, img_size=256, angles=None):
        if angles is None:
            # es. 60° totali => 0…π/3
            angles = np.linspace(0, np.pi/3, 90, dtype=np.float32)
        self.vol_geom = astra.create_vol_geom(img_size, img_size)
        self.proj_geom = astra.create_proj_geom(
            'parallel', 1.0, img_size, angles)
        self.fwd_id = astra.create_projector(
            'cuda', self.proj_geom, self.vol_geom)
    def forward(self, img):
        sinogram_id, sinogram = astra.create_sino(
            img.astype(np.float32), self.fwd_id)
        astra.data2d.delete(sinogram_id)
        return sinogram
    def adjoint(self, sino):
        rec_id = astra.data2d.create('-vol', self.vol_geom, sino)
        sinogram_id = astra.data2d.create('-sino', self.proj_geom, sino)
        cfg = astra.astra_dict('BP')
        cfg['SinoDataId'] = sinogram_id
        cfg['ReconstructionDataId'] = rec_id
        alg_id = astra.algorithm.create(cfg)
        astra.algorithm.run(alg_id)
        rec = astra.data2d.get(rec_id)
        # cleanup
        astra.algorithm.delete(alg_id)
        astra.data2d.delete(rec_id)
        astra.data2d.delete(sinogram_id)
        return rec


In [8]:
class ResidualUNet(nn.Module):
    def __init__(self, in_ch, out_ch, features=[64,128,256]):
        super().__init__()
        # Encoder
        self.downs = nn.ModuleList()
        self.ups   = nn.ModuleList()
        # costruisci blocchi encoder
        chs = [in_ch] + features
        for i in range(len(features)):
            self.downs.append(nn.Sequential(
                nn.Conv2d(chs[i], chs[i+1], 3, padding=1),
                nn.BatchNorm2d(chs[i+1]),
                nn.ReLU(inplace=True),
                nn.Conv2d(chs[i+1], chs[i+1], 3, padding=1),
                nn.BatchNorm2d(chs[i+1]),
                nn.ReLU(inplace=True)))
        # costruisci blocchi decoder
        rev_feats = features[::-1]
        for i in range(len(rev_feats)-1):
            self.ups.append(nn.Sequential(
                nn.Conv2d(rev_feats[i]*2, rev_feats[i+1], 3, padding=1),
                nn.BatchNorm2d(rev_feats[i+1]),
                nn.ReLU(inplace=True),
                nn.Conv2d(rev_feats[i+1], rev_feats[i+1], 3, padding=1),
                nn.BatchNorm2d(rev_feats[i+1]),
                nn.ReLU(inplace=True)))
        self.bottleneck = nn.Sequential(
            nn.Conv2d(features[-1], features[-1]*2, 3, padding=1),
            nn.BatchNorm2d(features[-1]*2),
            nn.ReLU(inplace=True),
            nn.Conv2d(features[-1]*2, features[-1], 3, padding=1),
            nn.BatchNorm2d(features[-1]),
            nn.ReLU(inplace=True))
        self.final = nn.Conv2d(features[0], out_ch, 1)
        self.pool = nn.MaxPool2d(2)
        self.upconv = nn.ModuleList([
            nn.ConvTranspose2d(f*2, f*2, 2, 2) for f in rev_feats
        ])
    def forward(self, x):
        skips = []
        for down in self.downs:
            x = down(x)
            skips.append(x)
            x = self.pool(x)
        x = self.bottleneck(x)
        for i, up in enumerate(self.ups):
            x = self.upconv[i](x)
            skip = skips[-i-1]
            x = torch.cat((x, skip), dim=1)
            x = up(x)
        return self.final(x) + x  # residual


In [9]:
class LearnedPDNet(nn.Module):
    def __init__(self, fwd_op, Iters=10):
        super().__init__()
        self.I = Iters
        self.A = fwd_op
        # un solo canale di stato primale e duale
        self.dual_cnn   = nn.ModuleList([ResidualUNet(2,1) for _ in range(Iters)])
        self.primal_cnn = nn.ModuleList([ResidualUNet(1,1) for _ in range(Iters)])
        # passi di gradiente (tuned a mano o learnable)
        self.sigma = 1.0
        self.tau   = 1.0
        self.gamma = 0.5
    def forward(self, g):
        # g: [B, 1, H, W] sinogram (espanso in immagine 2D)
        f = torch.zeros_like(self.A.adjoint(g.cpu().numpy())).to(g.device)
        h = torch.zeros_like(g)
        fbar = f
        for i in range(self.I):
            # dual update
            Af = self.A.forward(fbar.cpu().numpy())
            Af = torch.from_numpy(Af).to(g.device).unsqueeze(1)
            h = self.dual_cnn[i](torch.cat([h, Af-g], dim=1))
            # primal update
            Ah = self.A.adjoint(h.cpu().numpy())
            Ah = torch.from_numpy(Ah).to(g.device).unsqueeze(1)
            f = self.primal_cnn[i](f - self.tau*Ah)
            # over-relaxation
            fbar = f + self.gamma*(f - fbar)
        return f


In [10]:
class MajoDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.img_files = sorted(os.listdir(os.path.join(root_dir, 'train')))
        self.sino_dir  = os.path.join(root_dir, 'train')
        self.transform = transform
    def __len__(self):
        return len(self.img_files)
    def __getitem__(self, idx):
        img = np.load(os.path.join(self.sino_dir, self.img_files[idx]))
        sino = forward_op.forward(img)  # numpy array
        img = torch.from_numpy(img).float().unsqueeze(0)
        sino = torch.from_numpy(sino).float().unsqueeze(0)
        return sino, img

# setup
forward_op = AstraForward(img_size=256)
model = LearnedPDNet(forward_op).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.MSELoss()

train_ds = MajoDataset('/content/MajoDataset')
loader  = DataLoader(train_ds, batch_size=2, shuffle=True)

# training
for epoch in range(100):
    model.train()
    total_loss = 0
    for sino, gt in loader:
        sino, gt = sino.cuda(), gt.cuda()
        recon = model(sino)
        loss = criterion(recon, gt)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print(f"Epoch {epoch}: Loss = {total_loss/len(loader):.4f}")


IsADirectoryError: [Errno 21] Is a directory: '/content/MajoDataset/train/C027'