diff --git a/data/__init__.py b/data/__init__.py index e7448a0..3efc81b 100644 --- a/data/__init__.py +++ b/data/__init__.py @@ -51,7 +51,8 @@ def __getitem__(self, index): index = index % len(self.MASK) MASK, IMG = self.MASK[index], self.IMG[index] - MASK = color.rgb2gray(MASK) # shape of [h, w] + if len(MASK.shape) > 2: + MASK = color.rgb2gray(MASK) # shape of [h, w] NAME = (os.path.split(self.MASK_paths[index])[1]).split('.')[0] if len(IMG.shape) < 3: diff --git a/model/network_swin.py b/model/network.py similarity index 100% rename from model/network_swin.py rename to model/network.py