In [None]:
import torch
from torchvision.models import efficientnet_b3, EfficientNet_B3_Weights
from torchvision import transforms
import torchvision
import torchxrayvision as xrv
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from torchvision.io import read_image
import os

In [None]:
from torchvision.models._api import WeightsEnum
from torch.hub import load_state_dict_from_url

def get_state_dict(self, *args, **kwargs):
    """Replacement for torchvision's ``WeightsEnum.get_state_dict`` that drops hash checking.

    Any ``check_hash`` keyword supplied by the caller is discarded before
    delegating to ``load_state_dict_from_url``, so the checkpoint is loaded
    without the caller's hash-verification request (presumably to work around
    a published checkpoint whose hash does not match — confirm).

    Args:
        self: a weights enum member; only its ``url`` attribute is read.
        *args, **kwargs: forwarded unchanged (minus ``check_hash``).

    Returns:
        Whatever ``load_state_dict_from_url`` returns for ``self.url``.
    """
    # Default of None: callers that never pass check_hash must not hit a KeyError.
    kwargs.pop("check_hash", None)
    return load_state_dict_from_url(self.url, *args, **kwargs)
# Route every torchvision pretrained-weights download through the patched
# loader defined above, then build EfficientNet-B3 with its default weights.
WeightsEnum.get_state_dict = get_state_dict
effNet = efficientnet_b3(weights=EfficientNet_B3_Weights.DEFAULT)

Downloading: "https://download.pytorch.org/models/efficientnet_b3_rwightman-cf984f9c.pth" to /root/.cache/torch/hub/checkpoints/efficientnet_b3_rwightman-cf984f9c.pth
100%|██████████| 47.2M/47.2M [00:00<00:00, 114MB/s]


In [None]:
# Prefer the GPU when one is visible; otherwise everything runs on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
class LungDataset(torch.utils.data.Dataset):
  """Map-style dataset that decodes image files and yields ``(image, path)`` pairs.

  Args:
    img_paths: indexable sequence of image file paths.
    transforms: optional callable applied to each decoded image tensor;
      ``None`` returns the decoded tensor unchanged.
  """

  def __init__(self, img_paths, transforms):
    self.img_paths = img_paths
    self.transforms = transforms

  def __getitem__(self, idx):
    path = self.img_paths[idx]
    # read_image already returns a tensor; re-wrapping it with torch.tensor()
    # only made a redundant copy and triggered PyTorch's copy-construct UserWarning.
    img = read_image(path)
    if self.transforms is not None:
      img = self.transforms(img)
    return img, path

  def __len__(self):
    return len(self.img_paths)


In [None]:
class JustGiveLungs(object):
  """Per-image transform that keeps only the segmented lung region of a chest X-ray.

  Calling an instance runs torchxrayvision's PSPNet on one image, keeps the
  pixels covered by the (dilated) union of output channels 4 and 5, and crops
  to the bounding box of what remains.
  """

  def __init__(self, n = 5):
    # n: side length of the square dilation kernel applied to each mask.
    self.n = n
    self.seg_model = xrv.baseline_models.chestx_det.PSPNet().to(device)

  def __call__(self, img):
    # Returns the masked, cropped image (not the raw mask).
    return self.getMasks(img)

  def getMasks(self, img_org):
    """Segment one image and return its lung-masked, cropped pixels.

    Args:
      img_org: (C, H, W) tensor of 8-bit pixel values on the CPU (as produced
        by ``read_image``). NOTE(review): the 512x512 masks are multiplied
        directly against the image, so H == W == 512 is assumed (the notebook
        resizes to 512x512 upstream) — TODO confirm for other inputs.

    Returns:
      (C, h, w) float tensor cropped to the bounding box of the non-zero masked
      pixels. If no pixel survives the mask, the uncropped (all-zero) masked
      image is returned instead of raising.
    """
    img_hwc = torch.moveaxis(img_org, 0, -1)
    img1 = xrv.datasets.normalize(np.asarray(img_hwc), 255) # convert 8-bit image to [-1024, 1024] range
    img1 = img1.mean(2)[None, ...] # Make single color channel
    transform = torchvision.transforms.Compose([xrv.datasets.XRayCenterCrop(),xrv.datasets.XRayResizer(512, engine='cv2')])
    img = torch.from_numpy(transform(img1))

    with torch.no_grad():
      pred = self.seg_model(img.to(device))

    # Binarize in one step. The old "<0.5 -> 0, >0.5 -> 1" pair left exact 0.5
    # values as 0.5; those were treated as foreground by the dilation anyway.
    pred = (torch.sigmoid(pred) >= 0.5).float()

    # Channels 4 and 5 are presumably the left/right lung targets of the
    # PSPNet — confirm against seg_model.targets.
    mask_a = self.pad_tensor_with_ones(pred[0, 4], self.n)
    mask_b = self.pad_tensor_with_ones(pred[0, 5], self.n)

    # Union of both masks on the CPU with a trailing channel axis (H, W, 1) so
    # it broadcasts over the (H, W, C) image. This is exactly what the former
    # ToPILImage/ToTensor round-trip produced for 0/1 float masks (its Resize
    # step was commented out).
    combined_mask = torch.maximum(mask_a, mask_b).cpu()[..., None]

    combined_masked_image = img_hwc.float() * combined_mask
    return self._crop_to_content(combined_masked_image)

  @staticmethod
  def _crop_to_content(img_hwc):
    """Crop an (H, W, C) tensor to its non-zero bounding box and return it as (C, h, w).

    When every pixel is zero the uncropped image is returned (moved to
    channel-first), rather than failing on an empty index set.
    """
    rows, cols, _ = torch.nonzero(img_hwc, as_tuple=True)
    if rows.numel() == 0:
      return torch.moveaxis(img_hwc, 2, 0)
    r0, r1 = int(rows.min()), int(rows.max())
    c0, c1 = int(cols.min()), int(cols.max())
    return torch.moveaxis(img_hwc[r0:r1 + 1, c0:c1 + 1, :], 2, 0)

  def pad_tensor_with_ones(self, input_tensor, n):
    """Dilate a 2D binary mask with an n x n box kernel.

    Args:
      input_tensor: (H, W) mask.
      n: kernel side; with odd n the spatial size is preserved.

    Returns:
      (H, W) float tensor that is 1 wherever the n x n neighbourhood of the
      input contains a non-zero value, else 0.
    """
    tensor = input_tensor.unsqueeze(0).unsqueeze(0)  # Add batch and channel dimensions
    kernel = torch.ones(1, 1, n, n, device=tensor.device)
    padded_tensor = F.conv2d(tensor, kernel, padding=(n-1)//2)

    padded_tensor = padded_tensor.squeeze(0).squeeze(0)
    padded_tensor = (padded_tensor > 0).float()  # Convert non-zero values to 1

    return padded_tensor


In [None]:
def find_imgs(directory_path, image_extensions=(".jpg", ".jpeg", ".png", ".gif", ".bmp")):
    """Recursively collect image file paths under ``directory_path``.

    Args:
        directory_path: root directory to walk. A directory that does not
            exist yields an empty list (``os.walk`` reports nothing for it).
        image_extensions: filename suffixes treated as images, matched
            case-insensitively. Defaults to the previous hard-coded list.

    Returns:
        Sorted list of matching file paths. Sorting makes the result
        independent of filesystem traversal order, so seeded random sampling
        of the list is reproducible across machines.
    """
    # str.endswith accepts a tuple, replacing the per-extension any() loop.
    suffixes = tuple(ext.lower() for ext in image_extensions)
    return sorted(
        os.path.join(root, name)
        for root, _, files in os.walk(directory_path)
        for name in files
        if name.lower().endswith(suffixes)
    )

In [None]:
class JustGiveLungsBatched(object):
  """Batched lung cropper using the torchxrayvision preprocessing pipeline.

  Each image is preprocessed individually (normalize, grey, centre-crop,
  resize to 512), the batch is segmented in a single forward pass, and every
  image is cropped to its own lung bounding box.

  NOTE(review): a later cell in this notebook redefines
  ``JustGiveLungsBatched``; on Run All that later definition is the one in
  effect.
  """

  def __init__(self, n = 5):
    # n: side length of the square dilation kernel applied to each mask.
    self.n = n
    self.seg_model = xrv.baseline_models.chestx_det.PSPNet().to(device)

  def __call__(self, img):
    return self.getMasks(img)

  def getMasks(self, img_orgs):
    """Segment a batch and return one lung-cropped image per input.

    Args:
      img_orgs: (B, C, H, W) tensor of 8-bit pixel values. NOTE(review): the
        512x512 masks are multiplied directly against the images, so
        H == W == 512 is assumed — TODO confirm for other inputs.

    Returns:
      list of B float tensors, each (C, h_i, w_i). Crops differ in size, so
      they are returned as a list rather than a stacked tensor. An image with
      no surviving pixels is returned uncropped (all zeros).
    """
    imgs_hwc = torch.moveaxis(img_orgs, 1, -1).cpu()
    transform = torchvision.transforms.Compose([xrv.datasets.XRayCenterCrop(),xrv.datasets.XRayResizer(512, engine='cv2')])

    # The xrv transforms operate on single (1, H, W) numpy images, so
    # preprocess per image and stack into one (B, 1, 512, 512) batch.
    prepped = []
    for img_hwc in imgs_hwc:
      x = xrv.datasets.normalize(np.asarray(img_hwc), 255) # convert 8-bit image to [-1024, 1024] range
      x = x.mean(2)[None, ...] # Make single color channel
      prepped.append(transform(x))
    batch = torch.from_numpy(np.stack(prepped))

    with torch.no_grad():
      pred = self.seg_model(batch.to(device))

    pred = (torch.sigmoid(pred) >= 0.5).float()

    # Channels 4 and 5 are presumably the left/right lung targets — confirm
    # against seg_model.targets.
    mask_a = self.pad_tensor_with_ones(pred[:, 4], self.n)
    mask_b = self.pad_tensor_with_ones(pred[:, 5], self.n)
    combined_mask = torch.maximum(mask_a, mask_b).cpu()[..., None]  # (B, H, W, 1)

    masked = imgs_hwc.float() * combined_mask  # (B, H, W, C)
    return [self._crop_to_content(m) for m in masked]

  @staticmethod
  def _crop_to_content(img_hwc):
    """Crop an (H, W, C) tensor to its non-zero bounding box and return it as (C, h, w).

    When every pixel is zero the uncropped image is returned (channel-first).
    """
    rows, cols, _ = torch.nonzero(img_hwc, as_tuple=True)
    if rows.numel() == 0:
      return torch.moveaxis(img_hwc, 2, 0)
    r0, r1 = int(rows.min()), int(rows.max())
    c0, c1 = int(cols.min()), int(cols.max())
    return torch.moveaxis(img_hwc[r0:r1 + 1, c0:c1 + 1, :], 2, 0)

  def pad_tensor_with_ones(self, input_tensor, n):
    """Dilate binary masks with an n x n box kernel.

    Args:
      input_tensor: a (H, W) mask or a (B, H, W) batch of masks.
      n: kernel side; with odd n the spatial size is preserved.

    Returns:
      0/1 float tensor with the same leading dims as the input: 1 wherever
      the n x n neighbourhood contains a non-zero value.
    """
    tensor = input_tensor.reshape(-1, 1, *input_tensor.shape[-2:])
    kernel = torch.ones(1, 1, n, n, device=tensor.device)
    dilated = F.conv2d(tensor, kernel, padding=(n-1)//2)
    dilated = (dilated > 0).float()  # Convert non-zero values to 1
    return dilated.reshape(input_tensor.shape[:-2] + dilated.shape[-2:])


In [None]:
# Reproducible CheXpert subsample (up to 15000 distinct images) plus every
# image found under /content/train.
np.random.seed(10)
# Sorted so the seeded draw below does not depend on filesystem walk order.
cheXpert_imgs = sorted(find_imgs('/content/CheXpert-v1.0-small'))
n_chexpert = min(15000, len(cheXpert_imgs))
# replace=False: np.random.choice samples WITH replacement by default, which
# put duplicate images into the set. tolist() yields plain Python str paths.
cheXpert_s_imgs = np.random.choice(cheXpert_imgs, n_chexpert, replace=False).tolist()
img_paths = cheXpert_s_imgs + find_imgs('/content/train')

In [None]:
# Squash every image to a fixed 512x512 so batches collate (presumably also
# so the 512x512 segmentation masks line up with the inputs — confirm).
resize = transforms.Compose([transforms.Resize(size=[512, 512])])

In [None]:
# Resized images, 32 per shuffled batch; each item also carries its source path.
lung_ds = LungDataset(img_paths=img_paths, transforms=resize)
lung_dl = torch.utils.data.DataLoader(dataset=lung_ds, batch_size=32, shuffle=True)

In [None]:
from IPython.display import clear_output
save_path = '/gdrive/MyDrive/JustLungs/'
paths = []

# Preview pass: for each batch, save only its first (resized, not masked)
# image as "<batch index>.png"; paths[k] is the tuple of source paths of
# batch k.
for idx, (img, path) in enumerate(lung_dl):
  paths.append(path)
  preview = img[0] / 255  # scale 8-bit values into [0, 1] for save_image
  torchvision.utils.save_image(preview, f"{save_path}{idx}.png")
  clear_output(wait=True)
  print(idx)

In [None]:
class JustGiveLungsBatched(object):
  """Mask a batch of chest X-rays so only the (dilated) segmented lung pixels remain.

  The output keeps the input's (B, C, H, W) shape (no cropping), so results
  can be saved or stacked directly.
  """

  def __init__(self, n = 5):
    # n: side length of the square dilation kernel applied to each mask.
    self.n = n
    self.seg_model = xrv.baseline_models.chestx_det.PSPNet().to(device)

  def __call__(self, img):
    return self.getMasks(img)

  def getMasks(self, img_orgs):
    """Segment a batch and zero out everything outside the lung masks.

    Args:
      img_orgs: (B, C, H, W) tensor of 8-bit pixel values on any device.
        NOTE(review): no centre-crop/resize is applied here, so the model sees
        the images as given (512x512 in this notebook) — TODO confirm for other sizes.

    Returns:
      (B, C, H, W) float tensor on ``img_orgs.device``: the input pixels where
      the union of the dilated masks (output channels 4 and 5) is 1, else 0.
    """
    scaled = (2 * (img_orgs / 255.0) - 1.0) * 1024  # convert 8-bit image to [-1024, 1024] range
    gray = scaled.mean(1, keepdim=True)  # (B, 1, H, W) single color channel

    # Run on the configured device regardless of where the batch lives.
    with torch.no_grad():
      pred = self.seg_model(gray.to(device))

    # One-step binarization (the old two-step version left exact 0.5 values as 0.5).
    pred = (torch.sigmoid(pred) >= 0.5).float()

    # Channels 4 and 5 are presumably the left/right lung targets — confirm
    # against seg_model.targets.
    mask_a = self.pad_tensor_with_ones(pred[:, 4], self.n)
    mask_b = self.pad_tensor_with_ones(pred[:, 5], self.n)

    # (B, 1, H, W) broadcasts over the C channels; bring it back to the
    # images' device before multiplying.
    combined_mask = torch.maximum(mask_a, mask_b).to(img_orgs.device)
    return img_orgs.float() * combined_mask

  def pad_tensor_with_ones(self, input_tensor, n):
    """Dilate a (B, H, W) batch of binary masks with an n x n box kernel.

    Returns a (B, 1, H, W) 0/1 float tensor (spatial size preserved for odd n).
    The channel axis is always kept: the previous squeeze(0).squeeze(0) only
    removed dimensions when B == 1, so the return shape depended on batch size.
    """
    tensor = input_tensor.unsqueeze(1)  # (B, 1, H, W)
    kernel = torch.ones(1, 1, n, n, device=tensor.device)
    dilated = F.conv2d(tensor, kernel, padding=(n-1)//2)
    return (dilated > 0).float()  # Convert non-zero values to 1


In [None]:
from IPython.display import clear_output
paths = []
save_path = '/gdrive/MyDrive/JustLungs/'
os.makedirs(save_path, exist_ok=True)  # save_image does not create missing folders

# Build the segmenter here so the cell runs on a fresh kernel; no visible
# cell previously defined `gib`.
gib = JustGiveLungsBatched()
n_total = len(lung_ds)  # replaces the hard-coded 30000 in the progress print

idx = 0
for (imgs, path) in lung_dl:
  paths.extend(path)  # in-place append; paths[k] is the source of k.png
  imgs = imgs.to(device)  # was .cpu().cuda(), which fails on CPU-only machines
  out = gib.getMasks(imgs)
  for masked in out:
    torchvision.utils.save_image(masked / 255, os.path.join(save_path, f'{idx}.png'))
    idx += 1
  clear_output(wait=True)
  print(idx / n_total)

1.0
